Skip to content

Commit

Permalink
feat(src/utils): add cache to is_namedtuple and is_structseq
Browse files Browse the repository at this point in the history
  • Loading branch information
XuehaiPan committed Jan 2, 2024
1 parent 58eaf1a commit 8352ae9
Show file tree
Hide file tree
Showing 5 changed files with 170 additions and 99 deletions.
36 changes: 2 additions & 34 deletions include/registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ limitations under the License.
#include <unordered_map> // std::unordered_map
#include <utility> // std::pair

#include "include/utils.h"

namespace optree {

namespace py = pybind11;
Expand Down Expand Up @@ -92,40 +94,6 @@ class PyTreeTypeRegistry {
const py::function &unflatten_func,
const std::string &registry_namespace);

class TypeHash {
public:
using is_transparent = void;
size_t operator()(const py::object &t) const;
size_t operator()(const py::handle &t) const;
};
class TypeEq {
public:
using is_transparent = void;
bool operator()(const py::object &a, const py::object &b) const;
bool operator()(const py::object &a, const py::handle &b) const;
bool operator()(const py::handle &a, const py::object &b) const;
bool operator()(const py::handle &a, const py::handle &b) const;
};

class NamedTypeHash {
public:
using is_transparent = void;
size_t operator()(const std::pair<std::string, py::object> &p) const;
size_t operator()(const std::pair<std::string, py::handle> &p) const;
};
class NamedTypeEq {
public:
using is_transparent = void;
bool operator()(const std::pair<std::string, py::object> &a,
const std::pair<std::string, py::object> &b) const;
bool operator()(const std::pair<std::string, py::object> &a,
const std::pair<std::string, py::handle> &b) const;
bool operator()(const std::pair<std::string, py::handle> &a,
const std::pair<std::string, py::object> &b) const;
bool operator()(const std::pair<std::string, py::handle> &a,
const std::pair<std::string, py::handle> &b) const;
};

std::unordered_map<py::object, std::unique_ptr<Registration>, TypeHash, TypeEq>
m_registrations{};
std::unordered_map<std::pair<std::string, py::object>,
Expand Down
99 changes: 91 additions & 8 deletions include/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,13 @@ limitations under the License.

#include <pybind11/pybind11.h>

#include <exception> // std::rethrow_exception, std::current_exception
#include <functional> // std::hash
#include <sstream> // std::ostringstream
#include <string> // std::string
#include <utility> // std::move, std::pair, std::make_pair
#include <vector> // std::vector
#include <exception> // std::rethrow_exception, std::current_exception
#include <functional> // std::hash
#include <sstream> // std::ostringstream
#include <string> // std::string
#include <unordered_map> // std::unordered_map
#include <utility> // std::move, std::pair, std::make_pair
#include <vector> // std::vector

namespace py = pybind11;
using size_t = py::size_t;
Expand All @@ -50,6 +51,58 @@ inline void HashCombine(py::ssize_t& seed, const T& v) { // NOLINT[runtime/refe
seed ^= (hasher(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2));
}

class TypeHash {
public:
using is_transparent = void;
size_t operator()(const py::object& t) const { return std::hash<PyObject*>{}(t.ptr()); }
size_t operator()(const py::handle& t) const { return std::hash<PyObject*>{}(t.ptr()); }
};
class TypeEq {
public:
using is_transparent = void;
bool operator()(const py::object& a, const py::object& b) const { return a.ptr() == b.ptr(); }
bool operator()(const py::object& a, const py::handle& b) const { return a.ptr() == b.ptr(); }
bool operator()(const py::handle& a, const py::object& b) const { return a.ptr() == b.ptr(); }
bool operator()(const py::handle& a, const py::handle& b) const { return a.ptr() == b.ptr(); }
};

class NamedTypeHash {
public:
using is_transparent = void;
size_t operator()(const std::pair<std::string, py::object>& p) const {
size_t seed = 0;
HashCombine(seed, p.first);
HashCombine(seed, p.second.ptr());
return seed;
}
size_t operator()(const std::pair<std::string, py::handle>& p) const {
size_t seed = 0;
HashCombine(seed, p.first);
HashCombine(seed, p.second.ptr());
return seed;
}
};
class NamedTypeEq {
public:
using is_transparent = void;
bool operator()(const std::pair<std::string, py::object>& a,
const std::pair<std::string, py::object>& b) const {
return a.first == b.first && a.second.ptr() == b.second.ptr();
}
bool operator()(const std::pair<std::string, py::object>& a,
const std::pair<std::string, py::handle>& b) const {
return a.first == b.first && a.second.ptr() == b.second.ptr();
}
bool operator()(const std::pair<std::string, py::handle>& a,
const std::pair<std::string, py::object>& b) const {
return a.first == b.first && a.second.ptr() == b.second.ptr();
}
bool operator()(const std::pair<std::string, py::handle>& a,
const std::pair<std::string, py::handle>& b) const {
return a.first == b.first && a.second.ptr() == b.second.ptr();
}
};

constexpr bool NONE_IS_LEAF = true;
constexpr bool NONE_IS_NODE = false;

Expand Down Expand Up @@ -395,7 +448,22 @@ inline bool IsNamedTupleClassImpl(const py::handle& type) {
return false;
}
inline bool IsNamedTupleClass(const py::handle& type) {
return PyType_Check(type.ptr()) && IsNamedTupleClassImpl(type);
if (PyType_Check(type.ptr())) [[likely]] {
static auto cache = std::unordered_map<py::handle, bool, TypeHash, TypeEq>{};
auto it = cache.find(type);
if (it != cache.end()) [[likely]] {
return it->second;
}
bool result = IsNamedTupleClassImpl(type);
cache.emplace(type, result);
(void)py::weakref(type, py::cpp_function([type](py::handle weakref) -> void {
cache.erase(type);
weakref.dec_ref();
}))
.release();
return result;
}
return false;
}
inline bool IsNamedTupleInstance(const py::handle& object) {
return IsNamedTupleClass(py::type::handle_of(object));
Expand Down Expand Up @@ -458,7 +526,22 @@ inline bool IsStructSequenceClassImpl(const py::handle& type) {
return false;
}
inline bool IsStructSequenceClass(const py::handle& type) {
return PyType_Check(type.ptr()) && IsStructSequenceClassImpl(type);
if (PyType_Check(type.ptr())) [[likely]] {
static auto cache = std::unordered_map<py::handle, bool, TypeHash, TypeEq>{};
auto it = cache.find(type);
if (it != cache.end()) [[likely]] {
return it->second;
}
bool result = IsStructSequenceClassImpl(type);
cache.emplace(type, result);
(void)py::weakref(type, py::cpp_function([type](py::handle weakref) -> void {
cache.erase(type);
weakref.dec_ref();
}))
.release();
return result;
}
return false;
}
inline bool IsStructSequenceInstance(const py::handle& object) {
return IsStructSequenceClass(py::type::handle_of(object));
Expand Down
56 changes: 0 additions & 56 deletions src/registry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,60 +168,4 @@ template const PyTreeTypeRegistry::Registration* PyTreeTypeRegistry::Lookup<NONE
template const PyTreeTypeRegistry::Registration* PyTreeTypeRegistry::Lookup<NONE_IS_LEAF>(
const py::object&, const std::string&);

size_t PyTreeTypeRegistry::TypeHash::operator()(const py::object& t) const {
return std::hash<PyObject*>{}(t.ptr());
}
size_t PyTreeTypeRegistry::TypeHash::operator()(const py::handle& t) const {
return std::hash<PyObject*>{}(t.ptr());
}

bool PyTreeTypeRegistry::TypeEq::operator()(const py::object& a, const py::object& b) const {
return a.ptr() == b.ptr();
}
bool PyTreeTypeRegistry::TypeEq::operator()(const py::object& a, const py::handle& b) const {
return a.ptr() == b.ptr();
}
bool PyTreeTypeRegistry::TypeEq::operator()(const py::handle& a, const py::object& b) const {
return a.ptr() == b.ptr();
}
bool PyTreeTypeRegistry::TypeEq::operator()(const py::handle& a, const py::handle& b) const {
return a.ptr() == b.ptr();
}

size_t PyTreeTypeRegistry::NamedTypeHash::operator()(
const std::pair<std::string, py::object>& p) const {
size_t seed = 0;
HashCombine(seed, p.first);
HashCombine(seed, p.second.ptr());
return seed;
}
size_t PyTreeTypeRegistry::NamedTypeHash::operator()(
const std::pair<std::string, py::handle>& p) const {
size_t seed = 0;
HashCombine(seed, p.first);
HashCombine(seed, p.second.ptr());
return seed;
}

bool PyTreeTypeRegistry::NamedTypeEq::operator()(
const std::pair<std::string, py::object>& a,
const std::pair<std::string, py::object>& b) const {
return a.first == b.first && a.second.ptr() == b.second.ptr();
}
bool PyTreeTypeRegistry::NamedTypeEq::operator()(
const std::pair<std::string, py::object>& a,
const std::pair<std::string, py::handle>& b) const {
return a.first == b.first && a.second.ptr() == b.second.ptr();
}
bool PyTreeTypeRegistry::NamedTypeEq::operator()(
const std::pair<std::string, py::handle>& a,
const std::pair<std::string, py::object>& b) const {
return a.first == b.first && a.second.ptr() == b.second.ptr();
}
bool PyTreeTypeRegistry::NamedTypeEq::operator()(
const std::pair<std::string, py::handle>& a,
const std::pair<std::string, py::handle>& b) const {
return a.first == b.first && a.second.ptr() == b.second.ptr();
}

} // namespace optree
7 changes: 7 additions & 0 deletions tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

# pylint: disable=missing-class-docstring,missing-function-docstring,invalid-name

import gc
import itertools
import sys
import time
Expand All @@ -25,6 +26,12 @@
import optree


def getrefcount(obj=None):
for _ in range(10):
gc.collect()
return sys.getrefcount(obj)


def parametrize(**argvalues) -> pytest.mark.parametrize:
arguments = list(argvalues)
argvalues = list(itertools.product(*tuple(map(argvalues.get, arguments))))
Expand Down
71 changes: 70 additions & 1 deletion tests/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@
import re
import sys
import time
from collections import namedtuple

import pytest

import optree
from helpers import CustomNamedTupleSubclass, CustomTuple, Vector2D
from helpers import CustomNamedTupleSubclass, CustomTuple, Vector2D, getrefcount


class FakeNamedTuple(tuple):
Expand Down Expand Up @@ -82,6 +83,40 @@ def test_is_namedtuple():
assert not optree.is_namedtuple_class(FakeStructSequence)


def test_is_namedtuple_cache():
Point = namedtuple('Point', ('x', 'y')) # noqa: PYI024

refcount = getrefcount(Point)
assert optree.is_namedtuple(Point)
new_refcount = getrefcount(Point)
assert new_refcount == refcount

refcount = getrefcount(time.struct_time)
assert not optree.is_namedtuple(time.struct_time)
new_refcount = getrefcount(time.struct_time)
assert new_refcount == refcount

called_with = ''

class FooMeta(type):
def __del__(cls):
nonlocal called_with
called_with = cls.__name__

class Foo(metaclass=FooMeta):
pass

refcount = getrefcount(Foo)
assert not optree.is_namedtuple(Foo)
new_refcount = getrefcount(Foo)
assert new_refcount == refcount

assert called_with == ''
del Foo
getrefcount()
assert called_with == 'Foo'


def test_is_structseq():
with pytest.raises(TypeError, match="type 'structseq' is not an acceptable base type"):

Expand Down Expand Up @@ -118,6 +153,40 @@ class MyTuple(optree.typing.structseq):
assert not optree.is_structseq_class(FakeStructSequence)


def test_is_structseq_cache():
Point = namedtuple('Point', ('x', 'y')) # noqa: PYI024

refcount = getrefcount(Point)
assert not optree.is_structseq(Point)
new_refcount = getrefcount(Point)
assert new_refcount == refcount

refcount = getrefcount(time.struct_time)
assert optree.is_structseq(time.struct_time)
new_refcount = getrefcount(time.struct_time)
assert new_refcount == refcount

called_with = ''

class FooMeta(type):
def __del__(cls):
nonlocal called_with
called_with = cls.__name__

class Foo(metaclass=FooMeta):
pass

refcount = getrefcount(Foo)
assert not optree.is_namedtuple(Foo)
new_refcount = getrefcount(Foo)
assert new_refcount == refcount

assert called_with == ''
del Foo
getrefcount()
assert called_with == 'Foo'


def test_namedtuple_fields():
assert optree.namedtuple_fields(CustomTuple) == ('foo', 'bar')
assert optree.namedtuple_fields(CustomTuple(1, 2)) == ('foo', 'bar')
Expand Down

0 comments on commit 8352ae9

Please sign in to comment.