diff --git a/include/registry.h b/include/registry.h index 43dadb9e..85efb1f7 100644 --- a/include/registry.h +++ b/include/registry.h @@ -33,16 +33,16 @@ limitations under the License. namespace optree { enum class PyTreeKind { - Custom, // A custom type - Leaf, // An opaque leaf node - None, // None - Tuple, // A tuple - List, // A list - Dict, // A dict - NamedTuple, // A collections.namedtuple - OrderedDict, // A collections.OrderedDict - DefaultDict, // A collections.defaultdict - Deque, // A collections.deque + CUSTOM = 0, // A custom type + LEAF, // An opaque leaf node + NONE, // None + TUPLE, // A tuple + LIST, // A list + DICT, // A dict + NAMED_TUPLE, // A collections.namedtuple + ORDERED_DICT, // A collections.OrderedDict + DEFAULT_DICT, // A collections.defaultdict + DEQUE, // A collections.deque }; // Registry of custom node types. diff --git a/include/treespec.h b/include/treespec.h index 38d9998c..eefc70db 100644 --- a/include/treespec.h +++ b/include/treespec.h @@ -154,7 +154,7 @@ class PyTreeSpec { private: struct Node { - PyTreeKind kind = PyTreeKind::Leaf; + PyTreeKind kind = PyTreeKind::LEAF; // Arity for non-Leaf types. ssize_t arity = 0; diff --git a/src/registry.cpp b/src/registry.cpp index 13f4f5f5..ba23de7b 100644 --- a/src/registry.cpp +++ b/src/registry.cpp @@ -34,16 +34,16 @@ template <> registration->type = type; CHECK(registry.m_registrations.emplace(type, std::move(registration)).second); }; - add_builtin_type(Py_TYPE(Py_None), PyTreeKind::None); - add_builtin_type(&PyTuple_Type, PyTreeKind::Tuple); - add_builtin_type(&PyList_Type, PyTreeKind::List); - add_builtin_type(&PyDict_Type, PyTreeKind::Dict); + add_builtin_type(Py_TYPE(Py_None), PyTreeKind::NONE); + add_builtin_type(&PyTuple_Type, PyTreeKind::TUPLE); + add_builtin_type(&PyList_Type, PyTreeKind::LIST); + add_builtin_type(&PyDict_Type, PyTreeKind::DICT); add_builtin_type(reinterpret_cast(PyOrderedDictTypeObject.ptr()), - PyTreeKind::OrderedDict); + PyTreeKind::ORDERED_DICT); add_builtin_type(reinterpret_cast(PyDefaultDictTypeObject.ptr()), - PyTreeKind::DefaultDict); + PyTreeKind::DEFAULT_DICT); add_builtin_type(reinterpret_cast(PyDequeTypeObject.ptr()), - PyTreeKind::Deque); + PyTreeKind::DEQUE); return registry; }()); return ®istry; @@ -61,15 +61,15 @@ template <> registration->type = type; CHECK(registry.m_registrations.emplace(type, std::move(registration)).second); }; - add_builtin_type(&PyTuple_Type, PyTreeKind::Tuple); - add_builtin_type(&PyList_Type, PyTreeKind::List); - add_builtin_type(&PyDict_Type, PyTreeKind::Dict); + add_builtin_type(&PyTuple_Type, PyTreeKind::TUPLE); + add_builtin_type(&PyList_Type, PyTreeKind::LIST); + add_builtin_type(&PyDict_Type, PyTreeKind::DICT); add_builtin_type(reinterpret_cast(PyOrderedDictTypeObject.ptr()), - PyTreeKind::OrderedDict); + PyTreeKind::ORDERED_DICT); add_builtin_type(reinterpret_cast(PyDefaultDictTypeObject.ptr()), - PyTreeKind::DefaultDict); + PyTreeKind::DEFAULT_DICT); add_builtin_type(reinterpret_cast(PyDequeTypeObject.ptr()), - PyTreeKind::Deque); + PyTreeKind::DEQUE); return registry; }()); return ®istry; @@ -87,7 +87,7 @@ template { PyTreeTypeRegistry* registry = Singleton(); auto registration = std::make_unique(); - registration->kind = PyTreeKind::Custom; + registration->kind = PyTreeKind::CUSTOM; registration->type = py::reinterpret_borrow(cls); registration->to_iterable = py::reinterpret_borrow(to_iterable); registration->from_iterable = py::reinterpret_borrow(from_iterable); @@ -118,7 +118,7 @@ template { PyTreeTypeRegistry* registry = Singleton(); auto registration = std::make_unique(); - registration->kind = PyTreeKind::Custom; + registration->kind = PyTreeKind::CUSTOM; registration->type = py::reinterpret_borrow(cls); registration->to_iterable = py::reinterpret_borrow(to_iterable); registration->from_iterable = py::reinterpret_borrow(from_iterable); diff --git a/src/treespec/flatten.cpp b/src/treespec/flatten.cpp index 4869c1e3..47f0e51b 100644 --- a/src/treespec/flatten.cpp +++ b/src/treespec/flatten.cpp @@ -48,13 +48,13 @@ bool PyTreeSpec::FlattenIntoImpl(const py::handle& handle, child, leaves, depth + 1, leaf_predicate, registry_namespace); }; switch (node.kind) { - case PyTreeKind::None: + case PyTreeKind::NONE: if (!NoneIsLeaf) break; - case PyTreeKind::Leaf: + case PyTreeKind::LEAF: leaves.emplace_back(py::reinterpret_borrow(handle)); break; - case PyTreeKind::Tuple: { + case PyTreeKind::TUPLE: { node.arity = GET_SIZE(handle); for (ssize_t i = 0; i < node.arity; ++i) { recurse(GET_ITEM_HANDLE(handle, i)); @@ -62,7 +62,7 @@ bool PyTreeSpec::FlattenIntoImpl(const py::handle& handle, break; } - case PyTreeKind::List: { + case PyTreeKind::LIST: { node.arity = GET_SIZE(handle); for (ssize_t i = 0; i < node.arity; ++i) { recurse(GET_ITEM_HANDLE(handle, i)); @@ -70,12 +70,12 @@ bool PyTreeSpec::FlattenIntoImpl(const py::handle& handle, break; } - case PyTreeKind::Dict: - case PyTreeKind::OrderedDict: - case PyTreeKind::DefaultDict: { + case PyTreeKind::DICT: + case PyTreeKind::ORDERED_DICT: + case PyTreeKind::DEFAULT_DICT: { py::dict dict = py::reinterpret_borrow(handle); py::list keys; - if (node.kind == PyTreeKind::OrderedDict) [[unlikely]] { + if (node.kind == PyTreeKind::ORDERED_DICT) [[unlikely]] { keys = DictKeys(dict); } else [[likely]] { // NOLINT keys = SortedDictKeys(dict); @@ -84,7 +84,7 @@ bool PyTreeSpec::FlattenIntoImpl(const py::handle& handle, recurse(dict[key]); } node.arity = GET_SIZE(dict); - if (node.kind == PyTreeKind::DefaultDict) [[unlikely]] { + if (node.kind == PyTreeKind::DEFAULT_DICT) [[unlikely]] { node.node_data = py::make_tuple(py::getattr(handle, "default_factory"), std::move(keys)); } else [[likely]] { // NOLINT @@ -93,7 +93,7 @@ bool PyTreeSpec::FlattenIntoImpl(const py::handle& handle, break; } - case PyTreeKind::NamedTuple: { + case PyTreeKind::NAMED_TUPLE: { py::tuple tuple = py::reinterpret_borrow(handle); node.arity = GET_SIZE(tuple); node.node_data = py::reinterpret_borrow(tuple.get_type()); @@ -103,7 +103,7 @@ bool PyTreeSpec::FlattenIntoImpl(const py::handle& handle, break; } - case PyTreeKind::Deque: { + case PyTreeKind::DEQUE: { py::list list = handle.cast(); node.arity = GET_SIZE(list); node.node_data = py::getattr(handle, "maxlen"); @@ -113,7 +113,7 @@ bool PyTreeSpec::FlattenIntoImpl(const py::handle& handle, break; } - case PyTreeKind::Custom: { + case PyTreeKind::CUSTOM: { found_custom = true; py::tuple out = py::cast(node.custom->to_iterable(handle)); const size_t num_out = GET_SIZE(out); @@ -218,9 +218,9 @@ bool PyTreeSpec::FlattenIntoWithPathImpl(const py::handle& handle, stack.pop_back(); }; switch (node.kind) { - case PyTreeKind::None: + case PyTreeKind::NONE: if (!NoneIsLeaf) break; - case PyTreeKind::Leaf: { + case PyTreeKind::LEAF: { py::tuple path = py::tuple{depth}; for (ssize_t d = 0; d < depth; ++d) { SET_ITEM(path, d, stack[d]); @@ -230,7 +230,7 @@ bool PyTreeSpec::FlattenIntoWithPathImpl(const py::handle& handle, break; } - case PyTreeKind::Tuple: { + case PyTreeKind::TUPLE: { node.arity = GET_SIZE(handle); for (ssize_t i = 0; i < node.arity; ++i) { recurse(GET_ITEM_HANDLE(handle, i), py::int_(i)); @@ -238,7 +238,7 @@ bool PyTreeSpec::FlattenIntoWithPathImpl(const py::handle& handle, break; } - case PyTreeKind::List: { + case PyTreeKind::LIST: { node.arity = GET_SIZE(handle); for (ssize_t i = 0; i < node.arity; ++i) { recurse(GET_ITEM_HANDLE(handle, i), py::int_(i)); @@ -246,12 +246,12 @@ bool PyTreeSpec::FlattenIntoWithPathImpl(const py::handle& handle, break; } - case PyTreeKind::Dict: - case PyTreeKind::OrderedDict: - case PyTreeKind::DefaultDict: { + case PyTreeKind::DICT: + case PyTreeKind::ORDERED_DICT: + case PyTreeKind::DEFAULT_DICT: { py::dict dict = py::reinterpret_borrow(handle); py::list keys; - if (node.kind == PyTreeKind::OrderedDict) [[unlikely]] { + if (node.kind == PyTreeKind::ORDERED_DICT) [[unlikely]] { keys = DictKeys(dict); } else [[likely]] { // NOLINT keys = SortedDictKeys(dict); @@ -260,7 +260,7 @@ bool PyTreeSpec::FlattenIntoWithPathImpl(const py::handle& handle, recurse(dict[key], key); } node.arity = GET_SIZE(dict); - if (node.kind == PyTreeKind::DefaultDict) [[unlikely]] { + if (node.kind == PyTreeKind::DEFAULT_DICT) [[unlikely]] { node.node_data = py::make_tuple(py::getattr(handle, "default_factory"), std::move(keys)); } else [[likely]] { // NOLINT @@ -269,7 +269,7 @@ bool PyTreeSpec::FlattenIntoWithPathImpl(const py::handle& handle, break; } - case PyTreeKind::NamedTuple: { + case PyTreeKind::NAMED_TUPLE: { py::tuple tuple = py::reinterpret_borrow(handle); node.arity = GET_SIZE(tuple); node.node_data = py::reinterpret_borrow(tuple.get_type()); @@ -279,7 +279,7 @@ bool PyTreeSpec::FlattenIntoWithPathImpl(const py::handle& handle, break; } - case PyTreeKind::Deque: { + case PyTreeKind::DEQUE: { py::list list = handle.cast(); node.arity = GET_SIZE(list); node.node_data = py::getattr(handle, "maxlen"); @@ -289,7 +289,7 @@ bool PyTreeSpec::FlattenIntoWithPathImpl(const py::handle& handle, break; } - case PyTreeKind::Custom: { + case PyTreeKind::CUSTOM: { found_custom = true; py::tuple out = py::cast(node.custom->to_iterable(handle)); const size_t num_out = GET_SIZE(out); @@ -398,10 +398,10 @@ py::list PyTreeSpec::FlattenUpToImpl(const py::handle& full_tree) const { ++it; switch (node.kind) { - case PyTreeKind::None: + case PyTreeKind::NONE: break; - case PyTreeKind::Leaf: + case PyTreeKind::LEAF: if (leaf < 0) [[unlikely]] { throw std::logic_error("Leaf count mismatch."); } @@ -409,7 +409,7 @@ py::list PyTreeSpec::FlattenUpToImpl(const py::handle& full_tree) const { --leaf; break; - case PyTreeKind::Tuple: { + case PyTreeKind::TUPLE: { AssertExact(object); py::tuple tuple = py::reinterpret_borrow(object); if ((ssize_t)GET_SIZE(tuple) != node.arity) [[unlikely]] { @@ -425,7 +425,7 @@ py::list PyTreeSpec::FlattenUpToImpl(const py::handle& full_tree) const { break; } - case PyTreeKind::List: { + case PyTreeKind::LIST: { AssertExact(object); py::list list = py::reinterpret_borrow(object); if ((ssize_t)GET_SIZE(list) != node.arity) [[unlikely]] { @@ -441,16 +441,16 @@ py::list PyTreeSpec::FlattenUpToImpl(const py::handle& full_tree) const { break; } - case PyTreeKind::Dict: - case PyTreeKind::OrderedDict: { - if (node.kind == PyTreeKind::OrderedDict) [[unlikely]] { + case PyTreeKind::DICT: + case PyTreeKind::ORDERED_DICT: { + if (node.kind == PyTreeKind::ORDERED_DICT) [[unlikely]] { AssertExactOrderedDict(object); } else [[likely]] { // NOLINT AssertExact(object); } py::dict dict = py::reinterpret_borrow(object); py::list keys; - if (node.kind == PyTreeKind::OrderedDict) [[unlikely]] { + if (node.kind == PyTreeKind::ORDERED_DICT) [[unlikely]] { keys = DictKeys(dict); } else [[likely]] { // NOLINT keys = SortedDictKeys(dict); @@ -467,7 +467,7 @@ py::list PyTreeSpec::FlattenUpToImpl(const py::handle& full_tree) const { break; } - case PyTreeKind::NamedTuple: { + case PyTreeKind::NAMED_TUPLE: { AssertExactNamedTuple(object); py::tuple tuple = py::reinterpret_borrow(object); if ((ssize_t)GET_SIZE(tuple) != node.arity) [[unlikely]] { @@ -489,7 +489,7 @@ py::list PyTreeSpec::FlattenUpToImpl(const py::handle& full_tree) const { break; } - case PyTreeKind::DefaultDict: { + case PyTreeKind::DEFAULT_DICT: { AssertExactDefaultDict(object); py::dict dict = py::reinterpret_borrow(object); py::list keys = SortedDictKeys(dict); @@ -515,7 +515,7 @@ py::list PyTreeSpec::FlattenUpToImpl(const py::handle& full_tree) const { break; } - case PyTreeKind::Deque: { + case PyTreeKind::DEQUE: { AssertExactDeque(object); py::list list = py::cast(object); if ((ssize_t)GET_SIZE(list) != node.arity) [[unlikely]] { @@ -531,7 +531,7 @@ py::list PyTreeSpec::FlattenUpToImpl(const py::handle& full_tree) const { break; } - case PyTreeKind::Custom: { + case PyTreeKind::CUSTOM: { const PyTreeTypeRegistry::Registration* registration; if (m_none_is_leaf) [[unlikely]] { registration = @@ -595,7 +595,7 @@ template const std::string& registry_namespace) { const PyTreeTypeRegistry::Registration* custom; for (const py::handle& h : iterable) { - if (GetKind(h, &custom, registry_namespace) != PyTreeKind::Leaf) [[unlikely]] { + if (GetKind(h, &custom, registry_namespace) != PyTreeKind::LEAF) [[unlikely]] { return false; } } diff --git a/src/treespec/traversal.cpp b/src/treespec/traversal.cpp index 21301e77..c2f44ef3 100644 --- a/src/treespec/traversal.cpp +++ b/src/treespec/traversal.cpp @@ -31,7 +31,7 @@ py::object PyTreeSpec::Walk(const py::function& f_node, const bool f_leaf_identity = f_leaf.is_none(); for (const Node& node : m_traversal) { switch (node.kind) { - case PyTreeKind::Leaf: { + case PyTreeKind::LEAF: { if (it == leaves.end()) [[unlikely]] { throw std::invalid_argument("Too few leaves for PyTreeSpec"); } @@ -42,15 +42,15 @@ py::object PyTreeSpec::Walk(const py::function& f_node, break; } - case PyTreeKind::None: - case PyTreeKind::Tuple: - case PyTreeKind::NamedTuple: - case PyTreeKind::List: - case PyTreeKind::Dict: - case PyTreeKind::OrderedDict: - case PyTreeKind::DefaultDict: - case PyTreeKind::Deque: - case PyTreeKind::Custom: { + case PyTreeKind::NONE: + case PyTreeKind::TUPLE: + case PyTreeKind::NAMED_TUPLE: + case PyTreeKind::LIST: + case PyTreeKind::DICT: + case PyTreeKind::ORDERED_DICT: + case PyTreeKind::DEFAULT_DICT: + case PyTreeKind::DEQUE: + case PyTreeKind::CUSTOM: { if ((ssize_t)agenda.size() < node.arity) [[unlikely]] { throw std::logic_error("Too few elements for custom type."); } diff --git a/src/treespec/treespec.cpp b/src/treespec/treespec.cpp index d4e8d768..64b90aea 100644 --- a/src/treespec/treespec.cpp +++ b/src/treespec/treespec.cpp @@ -83,7 +83,7 @@ std::unique_ptr PyTreeSpec::Compose(const PyTreeSpec& inner_treespec outer_treespec->m_namespace = inner_treespec.m_namespace; } for (const Node& node : m_traversal) { - if (node.kind == PyTreeKind::Leaf) [[likely]] { + if (node.kind == PyTreeKind::LEAF) [[likely]] { absl::c_copy(inner_treespec.m_traversal, std::back_inserter(outer_treespec->m_traversal)); } else [[unlikely]] { // NOLINT @@ -126,7 +126,7 @@ std::unique_ptr PyTreeSpec::Compose(const PyTreeSpec& inner_treespec num_leaves += treespec.num_leaves(); } Node node; - node.kind = PyTreeKind::Tuple; + node.kind = PyTreeKind::TUPLE; node.arity = (ssize_t)treespecs.size(); node.num_leaves = num_leaves; node.num_nodes = (ssize_t)out->m_traversal.size() + 1; @@ -139,7 +139,7 @@ std::unique_ptr PyTreeSpec::Compose(const PyTreeSpec& inner_treespec /*static*/ std::unique_ptr PyTreeSpec::Leaf(const bool& none_is_leaf) { auto out = std::make_unique(); Node node; - node.kind = PyTreeKind::Leaf; + node.kind = PyTreeKind::LEAF; node.arity = 0; node.num_leaves = 1; node.num_nodes = 1; @@ -154,7 +154,7 @@ std::unique_ptr PyTreeSpec::Compose(const PyTreeSpec& inner_treespec } auto out = std::make_unique(); Node node; - node.kind = PyTreeKind::None; + node.kind = PyTreeKind::NONE; node.arity = 0; node.num_leaves = 0; node.num_nodes = 1; @@ -196,37 +196,37 @@ std::vector> PyTreeSpec::Children() const { throw std::logic_error("Node arity did not match."); } switch (node.kind) { - case PyTreeKind::Leaf: + case PyTreeKind::LEAF: throw std::logic_error("MakeNode not implemented for leaves."); - case PyTreeKind::None: + case PyTreeKind::NONE: return py::none(); - case PyTreeKind::Tuple: - case PyTreeKind::NamedTuple: { + case PyTreeKind::TUPLE: + case PyTreeKind::NAMED_TUPLE: { py::tuple tuple{node.arity}; for (ssize_t i = 0; i < node.arity; ++i) { SET_ITEM(tuple, i, children[i]); } - if (node.kind == PyTreeKind::NamedTuple) [[unlikely]] { + if (node.kind == PyTreeKind::NAMED_TUPLE) [[unlikely]] { return node.node_data(*tuple); } return std::move(tuple); } - case PyTreeKind::List: - case PyTreeKind::Deque: { + case PyTreeKind::LIST: + case PyTreeKind::DEQUE: { py::list list{node.arity}; for (ssize_t i = 0; i < node.arity; ++i) { SET_ITEM(list, i, children[i]); } - if (node.kind == PyTreeKind::Deque) [[unlikely]] { + if (node.kind == PyTreeKind::DEQUE) [[unlikely]] { return PyDequeTypeObject(list, py::arg("maxlen") = node.node_data); } return std::move(list); } - case PyTreeKind::Dict: { + case PyTreeKind::DICT: { py::dict dict; py::list keys = py::reinterpret_borrow(node.node_data); for (ssize_t i = 0; i < node.arity; ++i) { @@ -235,7 +235,7 @@ std::vector> PyTreeSpec::Children() const { return std::move(dict); } - case PyTreeKind::OrderedDict: { + case PyTreeKind::ORDERED_DICT: { py::list items{node.arity}; py::list keys = py::reinterpret_borrow(node.node_data); for (ssize_t i = 0; i < node.arity; ++i) { @@ -247,7 +247,7 @@ std::vector> PyTreeSpec::Children() const { return PyOrderedDictTypeObject(items); } - case PyTreeKind::DefaultDict: { + case PyTreeKind::DEFAULT_DICT: { py::dict dict; py::object default_factory = GET_ITEM_BORROW(node.node_data, 0); py::list keys = GET_ITEM_BORROW(node.node_data, 1); @@ -257,7 +257,7 @@ std::vector> PyTreeSpec::Children() const { return PyDefaultDictTypeObject(default_factory, dict); } - case PyTreeKind::Custom: { + case PyTreeKind::CUSTOM: { py::tuple tuple{node.arity}; for (ssize_t i = 0; i < node.arity; ++i) { SET_ITEM(tuple, i, children[i]); @@ -277,7 +277,7 @@ template const PyTreeTypeRegistry::Registration* registration = PyTreeTypeRegistry::Lookup(handle.get_type(), registry_namespace); if (registration) [[likely]] { // NOLINT - if (registration->kind == PyTreeKind::Custom) [[unlikely]] { + if (registration->kind == PyTreeKind::CUSTOM) [[unlikely]] { *custom = registration; } else [[likely]] { // NOLINT *custom = nullptr; @@ -286,9 +286,9 @@ template } *custom = nullptr; if (IsNamedTuple(handle)) [[unlikely]] { - return PyTreeKind::NamedTuple; + return PyTreeKind::NAMED_TUPLE; } - return PyTreeKind::Leaf; + return PyTreeKind::LEAF; } template PyTreeKind PyTreeSpec::GetKind(const py::handle&, @@ -308,15 +308,15 @@ std::string PyTreeSpec::ToString() const { std::string children = absl::StrJoin(agenda.end() - node.arity, agenda.end(), ", "); std::string representation; switch (node.kind) { - case PyTreeKind::Leaf: + case PyTreeKind::LEAF: agenda.emplace_back("*"); continue; - case PyTreeKind::None: + case PyTreeKind::NONE: representation = "None"; break; - case PyTreeKind::Tuple: + case PyTreeKind::TUPLE: // Tuples with only one element must have a trailing comma. if (node.arity == 1) [[unlikely]] { children += ","; @@ -324,10 +324,10 @@ std::string PyTreeSpec::ToString() const { representation = absl::StrCat("(", children, ")"); break; - case PyTreeKind::List: - case PyTreeKind::Deque: + case PyTreeKind::LIST: + case PyTreeKind::DEQUE: representation = absl::StrCat("[", children, "]"); - if (node.kind == PyTreeKind::Deque) [[unlikely]] { // NOLINT + if (node.kind == PyTreeKind::DEQUE) [[unlikely]] { // NOLINT if (node.node_data.is_none()) [[likely]] { representation = absl::StrCat("deque(", representation, ")"); } else [[unlikely]] { // NOLINT @@ -337,7 +337,7 @@ std::string PyTreeSpec::ToString() const { } break; - case PyTreeKind::Dict: { + case PyTreeKind::DICT: { if ((ssize_t)py::len(node.node_data) != node.arity) [[unlikely]] { throw std::logic_error("Number of keys and entries does not match."); } @@ -354,7 +354,7 @@ std::string PyTreeSpec::ToString() const { break; } - case PyTreeKind::NamedTuple: { + case PyTreeKind::NAMED_TUPLE: { py::object type = node.node_data; py::tuple fields = py::reinterpret_borrow(py::getattr(type, "_fields")); if ((ssize_t)py::len(fields) != node.arity) [[unlikely]] { @@ -378,7 +378,7 @@ std::string PyTreeSpec::ToString() const { break; } - case PyTreeKind::OrderedDict: { + case PyTreeKind::ORDERED_DICT: { if ((ssize_t)py::len(node.node_data) != node.arity) [[unlikely]] { throw std::logic_error("Number of keys and entries does not match."); } @@ -395,7 +395,7 @@ std::string PyTreeSpec::ToString() const { break; } - case PyTreeKind::DefaultDict: { + case PyTreeKind::DEFAULT_DICT: { if ((ssize_t)GET_SIZE(node.node_data) != 2) [[unlikely]] { throw std::logic_error("Number of auxiliary data mismatch."); } @@ -418,7 +418,7 @@ std::string PyTreeSpec::ToString() const { break; } - case PyTreeKind::Custom: { + case PyTreeKind::CUSTOM: { py::object type = node.custom->type; std::string kind = static_cast(py::str(py::getattr(type, "__name__"))); std::string data; @@ -486,16 +486,16 @@ py::object PyTreeSpec::ToPicklable() const { node.kind = static_cast(t[0].cast()); node.arity = t[1].cast(); switch (node.kind) { - case PyTreeKind::NamedTuple: + case PyTreeKind::NAMED_TUPLE: node.node_data = t[2].cast(); break; - case PyTreeKind::Dict: - case PyTreeKind::OrderedDict: + case PyTreeKind::DICT: + case PyTreeKind::ORDERED_DICT: node.node_data = t[2].cast(); break; - case PyTreeKind::DefaultDict: - case PyTreeKind::Deque: - case PyTreeKind::Custom: + case PyTreeKind::DEFAULT_DICT: + case PyTreeKind::DEQUE: + case PyTreeKind::CUSTOM: node.node_data = t[2]; break; default: @@ -504,7 +504,7 @@ py::object PyTreeSpec::ToPicklable() const { } break; } - if (node.kind == PyTreeKind::Custom) [[unlikely]] { // NOLINT + if (node.kind == PyTreeKind::CUSTOM) [[unlikely]] { // NOLINT if (!t[3].is_none()) [[unlikely]] { node.node_entries = t[3].cast(); } diff --git a/src/treespec/unflatten.cpp b/src/treespec/unflatten.cpp index 0658d22b..03ed5894 100644 --- a/src/treespec/unflatten.cpp +++ b/src/treespec/unflatten.cpp @@ -32,9 +32,9 @@ py::object PyTreeSpec::UnflattenImpl(const Span& leaves) const { throw std::logic_error("Too few elements for PyTreeSpec node."); } switch (node.kind) { - case PyTreeKind::None: - case PyTreeKind::Leaf: { - if (node.kind == PyTreeKind::Leaf || m_none_is_leaf) [[likely]] { // NOLINT + case PyTreeKind::NONE: + case PyTreeKind::LEAF: { + if (node.kind == PyTreeKind::LEAF || m_none_is_leaf) [[likely]] { // NOLINT if (it == leaves.end()) [[unlikely]] { throw std::invalid_argument( absl::StrFormat("Too few leaves for PyTreeSpec; expected %ld, got %ld.", @@ -48,14 +48,14 @@ py::object PyTreeSpec::UnflattenImpl(const Span& leaves) const { } } - case PyTreeKind::Tuple: - case PyTreeKind::NamedTuple: - case PyTreeKind::List: - case PyTreeKind::Dict: - case PyTreeKind::OrderedDict: - case PyTreeKind::DefaultDict: - case PyTreeKind::Deque: - case PyTreeKind::Custom: { + case PyTreeKind::TUPLE: + case PyTreeKind::NAMED_TUPLE: + case PyTreeKind::LIST: + case PyTreeKind::DICT: + case PyTreeKind::ORDERED_DICT: + case PyTreeKind::DEFAULT_DICT: + case PyTreeKind::DEQUE: + case PyTreeKind::CUSTOM: { const ssize_t size = (ssize_t)agenda.size(); absl::Span span; if (node.arity > 0) [[likely]] {