Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Propagate py::multiple_inheritance to all children #3650

Merged
merged 2 commits into from
Jan 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions include/pybind11/pybind11.h
Original file line number Diff line number Diff line change
Expand Up @@ -1206,10 +1206,13 @@ class generic_type : public object {
if (rec.bases.size() > 1 || rec.multiple_inheritance) {
mark_parents_nonsimple(tinfo->type);
tinfo->simple_ancestors = false;
tinfo->simple_type = false;
}
else if (rec.bases.size() == 1) {
auto parent_tinfo = get_type_info((PyTypeObject *) rec.bases[0].ptr());
tinfo->simple_ancestors = parent_tinfo->simple_ancestors;
// a child of a non-simple type can never be a simple type
tinfo->simple_type = parent_tinfo->simple_type;
}

Copy link
Collaborator

@Skylion007 Skylion007 Jan 26, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

According to the documentation and our testing for the TypeInfo type, a type with non simple_ancestors can never be a simple type. We should add an assert somewhere that verifies this:

Suggested change
// a simple type can never have non-simple ancestors
assert(!tinfo->simple_type || tinfo->simple_ancestors));

if (rec.module_local) {
Expand Down
83 changes: 83 additions & 0 deletions tests/test_multiple_inheritance.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -230,4 +230,87 @@ TEST_SUBMODULE(multiple_inheritance, m) {
.def("c1", [](C1 *self) { return self; });
py::class_<D, C0, C1>(m, "D")
.def(py::init<>());

// test_pr3635_diamond_*
// - functions are get_{base}_{var}, return {var}
struct MVB {
MVB() = default;
MVB(const MVB&) = default;
virtual ~MVB() = default;

int b = 1;
int get_b_b() const { return b; }
};
struct MVC : virtual MVB {
int c = 2;
int get_c_b() const { return b; }
int get_c_c() const { return c; }
};
struct MVD0 : virtual MVC {
int d0 = 3;
int get_d0_b() const { return b; }
int get_d0_c() const { return c; }
int get_d0_d0() const { return d0; }
};
struct MVD1 : virtual MVC {
int d1 = 4;
int get_d1_b() const { return b; }
int get_d1_c() const { return c; }
int get_d1_d1() const { return d1; }
};
struct MVE : virtual MVD0, virtual MVD1 {
int e = 5;
int get_e_b() const { return b; }
int get_e_c() const { return c; }
int get_e_d0() const { return d0; }
int get_e_d1() const { return d1; }
int get_e_e() const { return e; }
};
struct MVF : virtual MVE {
int f = 6;
int get_f_b() const { return b; }
int get_f_c() const { return c; }
int get_f_d0() const { return d0; }
int get_f_d1() const { return d1; }
int get_f_e() const { return e; }
int get_f_f() const { return f; }
};
py::class_<MVB>(m, "MVB")
.def(py::init<>())
.def("get_b_b", &MVB::get_b_b)
.def_readwrite("b", &MVB::b);
py::class_<MVC, MVB>(m, "MVC")
.def(py::init<>())
.def("get_c_b", &MVC::get_c_b)
.def("get_c_c", &MVC::get_c_c)
.def_readwrite("c", &MVC::c);
py::class_<MVD0, MVC>(m, "MVD0")
.def(py::init<>())
.def("get_d0_b", &MVD0::get_d0_b)
.def("get_d0_c", &MVD0::get_d0_c)
.def("get_d0_d0", &MVD0::get_d0_d0)
.def_readwrite("d0", &MVD0::d0);
py::class_<MVD1, MVC>(m, "MVD1")
.def(py::init<>())
.def("get_d1_b", &MVD1::get_d1_b)
.def("get_d1_c", &MVD1::get_d1_c)
.def("get_d1_d1", &MVD1::get_d1_d1)
.def_readwrite("d1", &MVD1::d1);
py::class_<MVE, MVD0, MVD1>(m, "MVE")
.def(py::init<>())
.def("get_e_b", &MVE::get_e_b)
.def("get_e_c", &MVE::get_e_c)
.def("get_e_d0", &MVE::get_e_d0)
.def("get_e_d1", &MVE::get_e_d1)
.def("get_e_e", &MVE::get_e_e)
.def_readwrite("e", &MVE::e);
py::class_<MVF, MVE>(m, "MVF")
.def(py::init<>())
.def("get_f_b", &MVF::get_f_b)
.def("get_f_c", &MVF::get_f_c)
.def("get_f_d0", &MVF::get_f_d0)
.def("get_f_d1", &MVF::get_f_d1)
.def("get_f_e", &MVF::get_f_e)
.def("get_f_f", &MVF::get_f_f)
.def_readwrite("f", &MVF::f);
}
114 changes: 114 additions & 0 deletions tests/test_multiple_inheritance.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,3 +358,117 @@ def test_diamond_inheritance():
assert d is d.c0().b()
assert d is d.c1().b()
assert d is d.c0().c1().b().c0().b()


def test_pr3635_diamond_b():
o = m.MVB()
assert o.b == 1

assert o.get_b_b() == 1


def test_pr3635_diamond_c():
o = m.MVC()
assert o.b == 1
assert o.c == 2

assert o.get_b_b() == 1
assert o.get_c_b() == 1

assert o.get_c_c() == 2


def test_pr3635_diamond_d0():
o = m.MVD0()
assert o.b == 1
assert o.c == 2
assert o.d0 == 3

assert o.get_b_b() == 1
assert o.get_c_b() == 1
assert o.get_d0_b() == 1

assert o.get_c_c() == 2
assert o.get_d0_c() == 2

assert o.get_d0_d0() == 3


def test_pr3635_diamond_d1():
o = m.MVD1()
assert o.b == 1
assert o.c == 2
assert o.d1 == 4

assert o.get_b_b() == 1
assert o.get_c_b() == 1
assert o.get_d1_b() == 1

assert o.get_c_c() == 2
assert o.get_d1_c() == 2

assert o.get_d1_d1() == 4


def test_pr3635_diamond_e():
o = m.MVE()
assert o.b == 1
assert o.c == 2
assert o.d0 == 3
assert o.d1 == 4
assert o.e == 5

assert o.get_b_b() == 1
assert o.get_c_b() == 1
assert o.get_d0_b() == 1
assert o.get_d1_b() == 1
assert o.get_e_b() == 1

assert o.get_c_c() == 2
assert o.get_d0_c() == 2
assert o.get_d1_c() == 2
assert o.get_e_c() == 2

assert o.get_d0_d0() == 3
assert o.get_e_d0() == 3

assert o.get_d1_d1() == 4
assert o.get_e_d1() == 4

assert o.get_e_e() == 5


def test_pr3635_diamond_f():
o = m.MVF()
assert o.b == 1
assert o.c == 2
assert o.d0 == 3
assert o.d1 == 4
assert o.e == 5
assert o.f == 6

assert o.get_b_b() == 1
assert o.get_c_b() == 1
assert o.get_d0_b() == 1
assert o.get_d1_b() == 1
assert o.get_e_b() == 1
assert o.get_f_b() == 1

assert o.get_c_c() == 2
assert o.get_d0_c() == 2
assert o.get_d1_c() == 2
assert o.get_e_c() == 2
assert o.get_f_c() == 2

assert o.get_d0_d0() == 3
assert o.get_e_d0() == 3
assert o.get_f_d0() == 3

assert o.get_d1_d1() == 4
assert o.get_e_d1() == 4
assert o.get_f_d1() == 4

assert o.get_e_e() == 5
assert o.get_f_e() == 5

assert o.get_f_f() == 6