diff --git a/include/registry.h b/include/registry.h index ff1f636f..f24501c9 100644 --- a/include/registry.h +++ b/include/registry.h @@ -24,6 +24,8 @@ limitations under the License. #include // std::unordered_map #include // std::pair +#include "include/utils.h" + namespace optree { namespace py = pybind11; @@ -92,40 +94,6 @@ class PyTreeTypeRegistry { const py::function &unflatten_func, const std::string ®istry_namespace); - class TypeHash { - public: - using is_transparent = void; - size_t operator()(const py::object &t) const; - size_t operator()(const py::handle &t) const; - }; - class TypeEq { - public: - using is_transparent = void; - bool operator()(const py::object &a, const py::object &b) const; - bool operator()(const py::object &a, const py::handle &b) const; - bool operator()(const py::handle &a, const py::object &b) const; - bool operator()(const py::handle &a, const py::handle &b) const; - }; - - class NamedTypeHash { - public: - using is_transparent = void; - size_t operator()(const std::pair &p) const; - size_t operator()(const std::pair &p) const; - }; - class NamedTypeEq { - public: - using is_transparent = void; - bool operator()(const std::pair &a, - const std::pair &b) const; - bool operator()(const std::pair &a, - const std::pair &b) const; - bool operator()(const std::pair &a, - const std::pair &b) const; - bool operator()(const std::pair &a, - const std::pair &b) const; - }; - std::unordered_map, TypeHash, TypeEq> m_registrations{}; std::unordered_map, diff --git a/include/utils.h b/include/utils.h index 4b46dbdb..cd239876 100644 --- a/include/utils.h +++ b/include/utils.h @@ -25,12 +25,13 @@ limitations under the License. #include -#include // std::rethrow_exception, std::current_exception -#include // std::hash -#include // std::ostringstream -#include // std::string -#include // std::move, std::pair, std::make_pair -#include // std::vector +#include // std::rethrow_exception, std::current_exception +#include // std::hash +#include // std::ostringstream +#include // std::string +#include // std::unordered_map +#include // std::move, std::pair, std::make_pair +#include // std::vector namespace py = pybind11; using size_t = py::size_t; @@ -50,6 +51,58 @@ inline void HashCombine(py::ssize_t& seed, const T& v) { // NOLINT[runtime/refe seed ^= (hasher(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2)); } +class TypeHash { + public: + using is_transparent = void; + size_t operator()(const py::object& t) const { return std::hash{}(t.ptr()); } + size_t operator()(const py::handle& t) const { return std::hash{}(t.ptr()); } +}; +class TypeEq { + public: + using is_transparent = void; + bool operator()(const py::object& a, const py::object& b) const { return a.ptr() == b.ptr(); } + bool operator()(const py::object& a, const py::handle& b) const { return a.ptr() == b.ptr(); } + bool operator()(const py::handle& a, const py::object& b) const { return a.ptr() == b.ptr(); } + bool operator()(const py::handle& a, const py::handle& b) const { return a.ptr() == b.ptr(); } +}; + +class NamedTypeHash { + public: + using is_transparent = void; + size_t operator()(const std::pair& p) const { + size_t seed = 0; + HashCombine(seed, p.first); + HashCombine(seed, p.second.ptr()); + return seed; + } + size_t operator()(const std::pair& p) const { + size_t seed = 0; + HashCombine(seed, p.first); + HashCombine(seed, p.second.ptr()); + return seed; + } +}; +class NamedTypeEq { + public: + using is_transparent = void; + bool operator()(const std::pair& a, + const std::pair& b) const { + return a.first == b.first && a.second.ptr() == b.second.ptr(); + } + bool operator()(const std::pair& a, + const std::pair& b) const { + return a.first == b.first && a.second.ptr() == b.second.ptr(); + } + bool operator()(const std::pair& a, + const std::pair& b) const { + return a.first == b.first && a.second.ptr() == b.second.ptr(); + } + bool operator()(const std::pair& a, + const std::pair& b) const { + return a.first == b.first && a.second.ptr() == b.second.ptr(); + } +}; + constexpr bool NONE_IS_LEAF = true; constexpr bool NONE_IS_NODE = false; @@ -395,7 +448,22 @@ inline bool IsNamedTupleClassImpl(const py::handle& type) { return false; } inline bool IsNamedTupleClass(const py::handle& type) { - return PyType_Check(type.ptr()) && IsNamedTupleClassImpl(type); + if (PyType_Check(type.ptr())) [[likely]] { + static auto cache = std::unordered_map{}; + auto it = cache.find(type); + if (it != cache.end()) [[likely]] { + return it->second; + } + bool result = IsNamedTupleClassImpl(type); + cache.emplace(type, result); + (void)py::weakref(type, py::cpp_function([type](py::handle weakref) -> void { + cache.erase(type); + weakref.dec_ref(); + })) + .release(); + return result; + } + return false; } inline bool IsNamedTupleInstance(const py::handle& object) { return IsNamedTupleClass(py::type::handle_of(object)); @@ -458,7 +526,22 @@ inline bool IsStructSequenceClassImpl(const py::handle& type) { return false; } inline bool IsStructSequenceClass(const py::handle& type) { - return PyType_Check(type.ptr()) && IsStructSequenceClassImpl(type); + if (PyType_Check(type.ptr())) [[likely]] { + static auto cache = std::unordered_map{}; + auto it = cache.find(type); + if (it != cache.end()) [[likely]] { + return it->second; + } + bool result = IsStructSequenceClassImpl(type); + cache.emplace(type, result); + (void)py::weakref(type, py::cpp_function([type](py::handle weakref) -> void { + cache.erase(type); + weakref.dec_ref(); + })) + .release(); + return result; + } + return false; } inline bool IsStructSequenceInstance(const py::handle& object) { return IsStructSequenceClass(py::type::handle_of(object)); diff --git a/src/registry.cpp b/src/registry.cpp index 6c098322..4fdf650f 100644 --- a/src/registry.cpp +++ b/src/registry.cpp @@ -168,60 +168,4 @@ template const PyTreeTypeRegistry::Registration* PyTreeTypeRegistry::Lookup( const py::object&, const std::string&); -size_t PyTreeTypeRegistry::TypeHash::operator()(const py::object& t) const { - return std::hash{}(t.ptr()); -} -size_t PyTreeTypeRegistry::TypeHash::operator()(const py::handle& t) const { - return std::hash{}(t.ptr()); -} - -bool PyTreeTypeRegistry::TypeEq::operator()(const py::object& a, const py::object& b) const { - return a.ptr() == b.ptr(); -} -bool PyTreeTypeRegistry::TypeEq::operator()(const py::object& a, const py::handle& b) const { - return a.ptr() == b.ptr(); -} -bool PyTreeTypeRegistry::TypeEq::operator()(const py::handle& a, const py::object& b) const { - return a.ptr() == b.ptr(); -} -bool PyTreeTypeRegistry::TypeEq::operator()(const py::handle& a, const py::handle& b) const { - return a.ptr() == b.ptr(); -} - -size_t PyTreeTypeRegistry::NamedTypeHash::operator()( - const std::pair& p) const { - size_t seed = 0; - HashCombine(seed, p.first); - HashCombine(seed, p.second.ptr()); - return seed; -} -size_t PyTreeTypeRegistry::NamedTypeHash::operator()( - const std::pair& p) const { - size_t seed = 0; - HashCombine(seed, p.first); - HashCombine(seed, p.second.ptr()); - return seed; -} - -bool PyTreeTypeRegistry::NamedTypeEq::operator()( - const std::pair& a, - const std::pair& b) const { - return a.first == b.first && a.second.ptr() == b.second.ptr(); -} -bool PyTreeTypeRegistry::NamedTypeEq::operator()( - const std::pair& a, - const std::pair& b) const { - return a.first == b.first && a.second.ptr() == b.second.ptr(); -} -bool PyTreeTypeRegistry::NamedTypeEq::operator()( - const std::pair& a, - const std::pair& b) const { - return a.first == b.first && a.second.ptr() == b.second.ptr(); -} -bool PyTreeTypeRegistry::NamedTypeEq::operator()( - const std::pair& a, - const std::pair& b) const { - return a.first == b.first && a.second.ptr() == b.second.ptr(); -} - } // namespace optree diff --git a/tests/helpers.py b/tests/helpers.py index c8123668..ccb8199e 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -15,6 +15,7 @@ # pylint: disable=missing-class-docstring,missing-function-docstring,invalid-name +import gc import itertools import sys import time @@ -25,6 +26,12 @@ import optree +def getrefcount(obj=None): + for _ in range(10): + gc.collect() + return sys.getrefcount(obj) + + def parametrize(**argvalues) -> pytest.mark.parametrize: arguments = list(argvalues) argvalues = list(itertools.product(*tuple(map(argvalues.get, arguments)))) diff --git a/tests/test_typing.py b/tests/test_typing.py index f82c1cbc..71dc2521 100644 --- a/tests/test_typing.py +++ b/tests/test_typing.py @@ -18,11 +18,12 @@ import re import sys import time +from collections import namedtuple import pytest import optree -from helpers import CustomNamedTupleSubclass, CustomTuple, Vector2D +from helpers import CustomNamedTupleSubclass, CustomTuple, Vector2D, getrefcount class FakeNamedTuple(tuple): @@ -82,6 +83,40 @@ def test_is_namedtuple(): assert not optree.is_namedtuple_class(FakeStructSequence) +def test_is_namedtuple_cache(): + Point = namedtuple('Point', ('x', 'y')) # noqa: PYI024 + + refcount = getrefcount(Point) + assert optree.is_namedtuple(Point) + new_refcount = getrefcount(Point) + assert new_refcount == refcount + + refcount = getrefcount(time.struct_time) + assert not optree.is_namedtuple(time.struct_time) + new_refcount = getrefcount(time.struct_time) + assert new_refcount == refcount + + called_with = '' + + class FooMeta(type): + def __del__(cls): + nonlocal called_with + called_with = cls.__name__ + + class Foo(metaclass=FooMeta): + pass + + refcount = getrefcount(Foo) + assert not optree.is_namedtuple(Foo) + new_refcount = getrefcount(Foo) + assert new_refcount == refcount + + assert called_with == '' + del Foo + getrefcount() + assert called_with == 'Foo' + + def test_is_structseq(): with pytest.raises(TypeError, match="type 'structseq' is not an acceptable base type"): @@ -118,6 +153,40 @@ class MyTuple(optree.typing.structseq): assert not optree.is_structseq_class(FakeStructSequence) +def test_is_structseq_cache(): + Point = namedtuple('Point', ('x', 'y')) # noqa: PYI024 + + refcount = getrefcount(Point) + assert not optree.is_structseq(Point) + new_refcount = getrefcount(Point) + assert new_refcount == refcount + + refcount = getrefcount(time.struct_time) + assert optree.is_structseq(time.struct_time) + new_refcount = getrefcount(time.struct_time) + assert new_refcount == refcount + + called_with = '' + + class FooMeta(type): + def __del__(cls): + nonlocal called_with + called_with = cls.__name__ + + class Foo(metaclass=FooMeta): + pass + + refcount = getrefcount(Foo) + assert not optree.is_namedtuple(Foo) + new_refcount = getrefcount(Foo) + assert new_refcount == refcount + + assert called_with == '' + del Foo + getrefcount() + assert called_with == 'Foo' + + def test_namedtuple_fields(): assert optree.namedtuple_fields(CustomTuple) == ('foo', 'bar') assert optree.namedtuple_fields(CustomTuple(1, 2)) == ('foo', 'bar')