Skip to content

Commit

Permalink
perf: Add object rvalue overload for accessors. Enables reference ste…
Browse files Browse the repository at this point in the history
…aling (#3970)

* Add object rvalue overload for accessors. Enables reference stealing

* Fix comments

* Fix more comment typos

* Fix bug

* reorder declarations for clarity

* fix another perf bug

* should be static

* future proof operator overloads

* Fix perfect forwarding

* Add a couple of tests

* Remove errant include

* Improve test documentation

* Add dict test

* add object attr tests

* Optimize STL map caster and cleanup enum

* Reorder to match declarations

* adjust increfs

* Remove comment

* revert value change

* add missing move
  • Loading branch information
Skylion007 authored Jun 1, 2022
1 parent 9f7b3f7 commit 58802de
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 12 deletions.
6 changes: 3 additions & 3 deletions include/pybind11/pybind11.h
Original file line number Diff line number Diff line change
Expand Up @@ -2069,12 +2069,12 @@ struct enum_base {
str name(name_);
if (entries.contains(name)) {
std::string type_name = (std::string) str(m_base.attr("__name__"));
throw value_error(type_name + ": element \"" + std::string(name_)
throw value_error(std::move(type_name) + ": element \"" + std::string(name_)
+ "\" already exists!");
}

entries[name] = std::make_pair(value, doc);
m_base.attr(name) = value;
m_base.attr(std::move(name)) = std::move(value);
}

PYBIND11_NOINLINE void export_values() {
Expand Down Expand Up @@ -2610,7 +2610,7 @@ PYBIND11_NOINLINE void print(const tuple &args, const dict &kwargs) {
}

auto write = file.attr("write");
write(line);
write(std::move(line));
write(kwargs.contains("end") ? kwargs["end"] : str("\n"));

if (kwargs.contains("flush") && kwargs["flush"].cast<bool>()) {
Expand Down
36 changes: 30 additions & 6 deletions include/pybind11/pytypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,9 @@ class object_api : public pyobject_tag {
or `object` subclass causes a call to ``__setitem__``.
\endrst */
item_accessor operator[](handle key) const;
/// See above (the only difference is that they key is provided as a string literal)
/// See above (the only difference is that the key's reference is stolen)
item_accessor operator[](object &&key) const;
/// See above (the only difference is that the key is provided as a string literal)
item_accessor operator[](const char *key) const;

/** \rst
Expand All @@ -95,7 +97,9 @@ class object_api : public pyobject_tag {
or `object` subclass causes a call to ``setattr``.
\endrst */
obj_attr_accessor attr(handle key) const;
/// See above (the only difference is that they key is provided as a string literal)
/// See above (the only difference is that the key's reference is stolen)
obj_attr_accessor attr(object &&key) const;
/// See above (the only difference is that the key is provided as a string literal)
str_attr_accessor attr(const char *key) const;

/** \rst
Expand Down Expand Up @@ -684,7 +688,7 @@ class accessor : public object_api<accessor<Policy>> {
}
template <typename T>
void operator=(T &&value) & {
get_cache() = reinterpret_borrow<object>(object_or_cast(std::forward<T>(value)));
get_cache() = ensure_object(object_or_cast(std::forward<T>(value)));
}

template <typename T = Policy>
Expand Down Expand Up @@ -712,6 +716,9 @@ class accessor : public object_api<accessor<Policy>> {
}

private:
static object ensure_object(object &&o) { return std::move(o); }
static object ensure_object(handle h) { return reinterpret_borrow<object>(h); }

object &get_cache() const {
if (!cache) {
cache = Policy::get(obj, key);
Expand Down Expand Up @@ -1711,7 +1718,10 @@ class tuple : public object {
size_t size() const { return (size_t) PyTuple_Size(m_ptr); }
bool empty() const { return size() == 0; }
detail::tuple_accessor operator[](size_t index) const { return {*this, index}; }
detail::item_accessor operator[](handle h) const { return object::operator[](h); }
template <typename T, detail::enable_if_t<detail::is_pyobject<T>::value, int> = 0>
detail::item_accessor operator[](T &&o) const {
return object::operator[](std::forward<T>(o));
}
detail::tuple_iterator begin() const { return {*this, 0}; }
detail::tuple_iterator end() const { return {*this, PyTuple_GET_SIZE(m_ptr)}; }
};
Expand Down Expand Up @@ -1771,7 +1781,10 @@ class sequence : public object {
}
bool empty() const { return size() == 0; }
detail::sequence_accessor operator[](size_t index) const { return {*this, index}; }
detail::item_accessor operator[](handle h) const { return object::operator[](h); }
template <typename T, detail::enable_if_t<detail::is_pyobject<T>::value, int> = 0>
detail::item_accessor operator[](T &&o) const {
return object::operator[](std::forward<T>(o));
}
detail::sequence_iterator begin() const { return {*this, 0}; }
detail::sequence_iterator end() const { return {*this, PySequence_Size(m_ptr)}; }
};
Expand All @@ -1790,7 +1803,10 @@ class list : public object {
size_t size() const { return (size_t) PyList_Size(m_ptr); }
bool empty() const { return size() == 0; }
detail::list_accessor operator[](size_t index) const { return {*this, index}; }
detail::item_accessor operator[](handle h) const { return object::operator[](h); }
template <typename T, detail::enable_if_t<detail::is_pyobject<T>::value, int> = 0>
detail::item_accessor operator[](T &&o) const {
return object::operator[](std::forward<T>(o));
}
detail::list_iterator begin() const { return {*this, 0}; }
detail::list_iterator end() const { return {*this, PyList_GET_SIZE(m_ptr)}; }
template <typename T>
Expand Down Expand Up @@ -2090,6 +2106,10 @@ item_accessor object_api<D>::operator[](handle key) const {
return {derived(), reinterpret_borrow<object>(key)};
}
template <typename D>
item_accessor object_api<D>::operator[](object &&key) const {
return {derived(), std::move(key)};
}
template <typename D>
item_accessor object_api<D>::operator[](const char *key) const {
return {derived(), pybind11::str(key)};
}
Expand All @@ -2098,6 +2118,10 @@ obj_attr_accessor object_api<D>::attr(handle key) const {
return {derived(), reinterpret_borrow<object>(key)};
}
template <typename D>
obj_attr_accessor object_api<D>::attr(object &&key) const {
return {derived(), std::move(key)};
}
template <typename D>
str_attr_accessor object_api<D>::attr(const char *key) const {
return {derived(), key};
}
Expand Down
2 changes: 1 addition & 1 deletion include/pybind11/stl.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ struct map_caster {
if (!key || !value) {
return handle();
}
d[key] = value;
d[std::move(key)] = std::move(value);
}
return d.release();
}
Expand Down
34 changes: 34 additions & 0 deletions tests/test_pytypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -661,4 +661,38 @@ TEST_SUBMODULE(pytypes, m) {
double v = x.get_value();
return v * v;
});

m.def("tuple_rvalue_getter", [](const py::tuple &tup) {
// tests accessing tuple object with rvalue int
for (size_t i = 0; i < tup.size(); i++) {
auto o = py::handle(tup[py::int_(i)]);
if (!o) {
throw py::value_error("tuple is malformed");
}
}
return tup;
});
m.def("list_rvalue_getter", [](const py::list &l) {
// tests accessing list with rvalue int
for (size_t i = 0; i < l.size(); i++) {
auto o = py::handle(l[py::int_(i)]);
if (!o) {
throw py::value_error("list is malformed");
}
}
return l;
});
m.def("populate_dict_rvalue", [](int population) {
auto d = py::dict();
for (int i = 0; i < population; i++) {
d[py::int_(i)] = py::int_(i);
}
return d;
});
m.def("populate_obj_str_attrs", [](py::object &o, int population) {
for (int i = 0; i < population; i++) {
o.attr(py::str(py::int_(i))) = py::str(py::int_(i));
}
return o;
});
}
31 changes: 29 additions & 2 deletions tests/test_pytypes.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import contextlib
import sys
import types

import pytest

Expand Down Expand Up @@ -320,8 +321,7 @@ def func(self, x, *args):
def test_accessor_moves():
inc_refs = m.accessor_moves()
if inc_refs:
# To be changed in PR #3970: [1, 0, 1, 0, ...]
assert inc_refs == [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
assert inc_refs == [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0]
else:
pytest.skip("Not defined: PYBIND11_HANDLE_REF_DEBUG")

Expand Down Expand Up @@ -707,3 +707,30 @@ def test_implementation_details():
def test_external_float_():
r1 = m.square_float_(2.0)
assert r1 == 4.0


def test_tuple_rvalue_getter():
pop = 1000
tup = tuple(range(pop))
m.tuple_rvalue_getter(tup)


def test_list_rvalue_getter():
pop = 1000
my_list = list(range(pop))
m.list_rvalue_getter(my_list)


def test_populate_dict_rvalue():
pop = 1000
my_dict = {i: i for i in range(pop)}
assert m.populate_dict_rvalue(pop) == my_dict


def test_populate_obj_str_attrs():
pop = 1000
o = types.SimpleNamespace(**{str(i): i for i in range(pop)})
new_o = m.populate_obj_str_attrs(o, pop)
new_attrs = {k: v for k, v in new_o.__dict__.items() if not k.startswith("_")}
assert all(isinstance(v, str) for v in new_attrs.values())
assert len(new_attrs) == pop

0 comments on commit 58802de

Please sign in to comment.