From 31b7e14052b547a5ddf97c07840a830ec27524b7 Mon Sep 17 00:00:00 2001 From: Huanchen Zhai <44282896+hczhai@users.noreply.github.com> Date: Sat, 13 Jan 2024 11:07:36 -0800 Subject: [PATCH] bugfix: removing typing and duplicate ``class_`` for KeysView/ValuesView/ItemsView. Fix #4529 (#4985) * remove typing for KeysView/ValuesView/ItemsView * add tests for map view types --- include/pybind11/stl_bind.h | 79 ++++++++++++------------------------- tests/test_stl_binders.cpp | 10 +++++ tests/test_stl_binders.py | 30 ++++++++++++-- 3 files changed, 62 insertions(+), 57 deletions(-) diff --git a/include/pybind11/stl_bind.h b/include/pybind11/stl_bind.h index 4d0854f6a8..a226cbc0e8 100644 --- a/include/pybind11/stl_bind.h +++ b/include/pybind11/stl_bind.h @@ -645,49 +645,50 @@ auto map_if_insertion_operator(Class_ &cl, std::string const &name) "Return the canonical string representation of this map."); } -template struct keys_view { virtual size_t len() = 0; virtual iterator iter() = 0; - virtual bool contains(const KeyType &k) = 0; - virtual bool contains(const object &k) = 0; + virtual bool contains(const handle &k) = 0; virtual ~keys_view() = default; }; -template struct values_view { virtual size_t len() = 0; virtual iterator iter() = 0; virtual ~values_view() = default; }; -template struct items_view { virtual size_t len() = 0; virtual iterator iter() = 0; virtual ~items_view() = default; }; -template -struct KeysViewImpl : public KeysView { +template +struct KeysViewImpl : public detail::keys_view { explicit KeysViewImpl(Map &map) : map(map) {} size_t len() override { return map.size(); } iterator iter() override { return make_key_iterator(map.begin(), map.end()); } - bool contains(const typename Map::key_type &k) override { return map.find(k) != map.end(); } - bool contains(const object &) override { return false; } + bool contains(const handle &k) override { + try { + return map.find(k.template cast()) != map.end(); + } catch (const cast_error &) { + return false; + } + } Map ↦ }; -template -struct ValuesViewImpl : public ValuesView { +template +struct ValuesViewImpl : public detail::values_view { explicit ValuesViewImpl(Map &map) : map(map) {} size_t len() override { return map.size(); } iterator iter() override { return make_value_iterator(map.begin(), map.end()); } Map ↦ }; -template -struct ItemsViewImpl : public ItemsView { +template +struct ItemsViewImpl : public detail::items_view { explicit ItemsViewImpl(Map &map) : map(map) {} size_t len() override { return map.size(); } iterator iter() override { return make_iterator(map.begin(), map.end()); } @@ -700,11 +701,9 @@ template , typename... class_ bind_map(handle scope, const std::string &name, Args &&...args) { using KeyType = typename Map::key_type; using MappedType = typename Map::mapped_type; - using StrippedKeyType = detail::remove_cvref_t; - using StrippedMappedType = detail::remove_cvref_t; - using KeysView = detail::keys_view; - using ValuesView = detail::values_view; - using ItemsView = detail::items_view; + using KeysView = detail::keys_view; + using ValuesView = detail::values_view; + using ItemsView = detail::items_view; using Class_ = class_; // If either type is a non-module-local bound type then make the map binding non-local as well; @@ -718,39 +717,20 @@ class_ bind_map(handle scope, const std::string &name, Args && } Class_ cl(scope, name.c_str(), pybind11::module_local(local), std::forward(args)...); - static constexpr auto key_type_descr = detail::make_caster::name; - static constexpr auto mapped_type_descr = detail::make_caster::name; - std::string key_type_name(key_type_descr.text), mapped_type_name(mapped_type_descr.text); - // If key type isn't properly wrapped, fall back to C++ names - if (key_type_name == "%") { - key_type_name = detail::type_info_description(typeid(KeyType)); - } - // Similarly for value type: - if (mapped_type_name == "%") { - mapped_type_name = detail::type_info_description(typeid(MappedType)); - } - - // Wrap KeysView[KeyType] if it wasn't already wrapped + // Wrap KeysView if it wasn't already wrapped if (!detail::get_type_info(typeid(KeysView))) { - class_ keys_view( - scope, ("KeysView[" + key_type_name + "]").c_str(), pybind11::module_local(local)); + class_ keys_view(scope, "KeysView", pybind11::module_local(local)); keys_view.def("__len__", &KeysView::len); keys_view.def("__iter__", &KeysView::iter, keep_alive<0, 1>() /* Essential: keep view alive while iterator exists */ ); - keys_view.def("__contains__", - static_cast(&KeysView::contains)); - // Fallback for when the object is not of the key type - keys_view.def("__contains__", - static_cast(&KeysView::contains)); + keys_view.def("__contains__", &KeysView::contains); } // Similarly for ValuesView: if (!detail::get_type_info(typeid(ValuesView))) { - class_ values_view(scope, - ("ValuesView[" + mapped_type_name + "]").c_str(), - pybind11::module_local(local)); + class_ values_view(scope, "ValuesView", pybind11::module_local(local)); values_view.def("__len__", &ValuesView::len); values_view.def("__iter__", &ValuesView::iter, @@ -759,10 +739,7 @@ class_ bind_map(handle scope, const std::string &name, Args && } // Similarly for ItemsView: if (!detail::get_type_info(typeid(ItemsView))) { - class_ items_view( - scope, - ("ItemsView[" + key_type_name + ", ").append(mapped_type_name + "]").c_str(), - pybind11::module_local(local)); + class_ items_view(scope, "ItemsView", pybind11::module_local(local)); items_view.def("__len__", &ItemsView::len); items_view.def("__iter__", &ItemsView::iter, @@ -788,25 +765,19 @@ class_ bind_map(handle scope, const std::string &name, Args && cl.def( "keys", - [](Map &m) { - return std::unique_ptr(new detail::KeysViewImpl(m)); - }, + [](Map &m) { return std::unique_ptr(new detail::KeysViewImpl(m)); }, keep_alive<0, 1>() /* Essential: keep map alive while view exists */ ); cl.def( "values", - [](Map &m) { - return std::unique_ptr(new detail::ValuesViewImpl(m)); - }, + [](Map &m) { return std::unique_ptr(new detail::ValuesViewImpl(m)); }, keep_alive<0, 1>() /* Essential: keep map alive while view exists */ ); cl.def( "items", - [](Map &m) { - return std::unique_ptr(new detail::ItemsViewImpl(m)); - }, + [](Map &m) { return std::unique_ptr(new detail::ItemsViewImpl(m)); }, keep_alive<0, 1>() /* Essential: keep map alive while view exists */ ); diff --git a/tests/test_stl_binders.cpp b/tests/test_stl_binders.cpp index e52a03b6d2..96af9a4c4b 100644 --- a/tests/test_stl_binders.cpp +++ b/tests/test_stl_binders.cpp @@ -192,6 +192,16 @@ TEST_SUBMODULE(stl_binders, m) { py::bind_map>(m, "UnorderedMapStringDoubleConst"); + // test_map_view_types + py::bind_map>(m, "MapStringFloat"); + py::bind_map>(m, "UnorderedMapStringFloat"); + + py::bind_map, int32_t>>(m, "MapPairDoubleIntInt32"); + py::bind_map, int64_t>>(m, "MapPairDoubleIntInt64"); + + py::bind_map>(m, "MapIntObject"); + py::bind_map>(m, "MapStringObject"); + py::class_(m, "ENC").def(py::init()).def_readwrite("value", &E_nc::value); // test_noncopyable_containers diff --git a/tests/test_stl_binders.py b/tests/test_stl_binders.py index c00d45b926..d1bf64aa02 100644 --- a/tests/test_stl_binders.py +++ b/tests/test_stl_binders.py @@ -317,9 +317,9 @@ def test_map_view_types(): map_string_double_const = m.MapStringDoubleConst() unordered_map_string_double_const = m.UnorderedMapStringDoubleConst() - assert map_string_double.keys().__class__.__name__ == "KeysView[str]" - assert map_string_double.values().__class__.__name__ == "ValuesView[float]" - assert map_string_double.items().__class__.__name__ == "ItemsView[str, float]" + assert map_string_double.keys().__class__.__name__ == "KeysView" + assert map_string_double.values().__class__.__name__ == "ValuesView" + assert map_string_double.items().__class__.__name__ == "ItemsView" keys_type = type(map_string_double.keys()) assert type(unordered_map_string_double.keys()) is keys_type @@ -336,6 +336,30 @@ def test_map_view_types(): assert type(map_string_double_const.items()) is items_type assert type(unordered_map_string_double_const.items()) is items_type + map_string_float = m.MapStringFloat() + unordered_map_string_float = m.UnorderedMapStringFloat() + + assert type(map_string_float.keys()) is keys_type + assert type(unordered_map_string_float.keys()) is keys_type + assert type(map_string_float.values()) is values_type + assert type(unordered_map_string_float.values()) is values_type + assert type(map_string_float.items()) is items_type + assert type(unordered_map_string_float.items()) is items_type + + map_pair_double_int_int32 = m.MapPairDoubleIntInt32() + map_pair_double_int_int64 = m.MapPairDoubleIntInt64() + + assert type(map_pair_double_int_int32.values()) is values_type + assert type(map_pair_double_int_int64.values()) is values_type + + map_int_object = m.MapIntObject() + map_string_object = m.MapStringObject() + + assert type(map_int_object.keys()) is keys_type + assert type(map_string_object.keys()) is keys_type + assert type(map_int_object.items()) is items_type + assert type(map_string_object.items()) is items_type + def test_recursive_vector(): recursive_vector = m.RecursiveVector()