Skip to content

Commit

Permalink
refactor: use UPPER_CASE name for enum constants
Browse files Browse the repository at this point in the history
  • Loading branch information
XuehaiPan committed Nov 22, 2022
1 parent 5b49e6c commit c70cd6a
Show file tree
Hide file tree
Showing 7 changed files with 121 additions and 121 deletions.
20 changes: 10 additions & 10 deletions include/registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,16 @@ limitations under the License.
namespace optree {

enum class PyTreeKind {
Custom, // A custom type
Leaf, // An opaque leaf node
None, // None
Tuple, // A tuple
List, // A list
Dict, // A dict
NamedTuple, // A collections.namedtuple
OrderedDict, // A collections.OrderedDict
DefaultDict, // A collections.defaultdict
Deque, // A collections.deque
CUSTOM = 0, // A custom type
LEAF, // An opaque leaf node
NONE, // None
TUPLE, // A tuple
LIST, // A list
DICT, // A dict
NAMED_TUPLE, // A collections.namedtuple
ORDERED_DICT, // A collections.OrderedDict
DEFAULT_DICT, // A collections.defaultdict
DEQUE, // A collections.deque
};

// Registry of custom node types.
Expand Down
2 changes: 1 addition & 1 deletion include/treespec.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ class PyTreeSpec {

private:
struct Node {
PyTreeKind kind = PyTreeKind::Leaf;
PyTreeKind kind = PyTreeKind::LEAF;

// Arity for non-Leaf types.
ssize_t arity = 0;
Expand Down
30 changes: 15 additions & 15 deletions src/registry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,16 @@ template <>
registration->type = type;
CHECK(registry.m_registrations.emplace(type, std::move(registration)).second);
};
add_builtin_type(Py_TYPE(Py_None), PyTreeKind::None);
add_builtin_type(&PyTuple_Type, PyTreeKind::Tuple);
add_builtin_type(&PyList_Type, PyTreeKind::List);
add_builtin_type(&PyDict_Type, PyTreeKind::Dict);
add_builtin_type(Py_TYPE(Py_None), PyTreeKind::NONE);
add_builtin_type(&PyTuple_Type, PyTreeKind::TUPLE);
add_builtin_type(&PyList_Type, PyTreeKind::LIST);
add_builtin_type(&PyDict_Type, PyTreeKind::DICT);
add_builtin_type(reinterpret_cast<PyTypeObject*>(PyOrderedDictTypeObject.ptr()),
PyTreeKind::OrderedDict);
PyTreeKind::ORDERED_DICT);
add_builtin_type(reinterpret_cast<PyTypeObject*>(PyDefaultDictTypeObject.ptr()),
PyTreeKind::DefaultDict);
PyTreeKind::DEFAULT_DICT);
add_builtin_type(reinterpret_cast<PyTypeObject*>(PyDequeTypeObject.ptr()),
PyTreeKind::Deque);
PyTreeKind::DEQUE);
return registry;
}());
return &registry;
Expand All @@ -61,15 +61,15 @@ template <>
registration->type = type;
CHECK(registry.m_registrations.emplace(type, std::move(registration)).second);
};
add_builtin_type(&PyTuple_Type, PyTreeKind::Tuple);
add_builtin_type(&PyList_Type, PyTreeKind::List);
add_builtin_type(&PyDict_Type, PyTreeKind::Dict);
add_builtin_type(&PyTuple_Type, PyTreeKind::TUPLE);
add_builtin_type(&PyList_Type, PyTreeKind::LIST);
add_builtin_type(&PyDict_Type, PyTreeKind::DICT);
add_builtin_type(reinterpret_cast<PyTypeObject*>(PyOrderedDictTypeObject.ptr()),
PyTreeKind::OrderedDict);
PyTreeKind::ORDERED_DICT);
add_builtin_type(reinterpret_cast<PyTypeObject*>(PyDefaultDictTypeObject.ptr()),
PyTreeKind::DefaultDict);
PyTreeKind::DEFAULT_DICT);
add_builtin_type(reinterpret_cast<PyTypeObject*>(PyDequeTypeObject.ptr()),
PyTreeKind::Deque);
PyTreeKind::DEQUE);
return registry;
}());
return &registry;
Expand All @@ -87,7 +87,7 @@ template <bool NoneIsLeaf>
{
PyTreeTypeRegistry* registry = Singleton<NONE_IS_NODE>();
auto registration = std::make_unique<Registration>();
registration->kind = PyTreeKind::Custom;
registration->kind = PyTreeKind::CUSTOM;
registration->type = py::reinterpret_borrow<py::object>(cls);
registration->to_iterable = py::reinterpret_borrow<py::function>(to_iterable);
registration->from_iterable = py::reinterpret_borrow<py::function>(from_iterable);
Expand Down Expand Up @@ -118,7 +118,7 @@ template <bool NoneIsLeaf>
{
PyTreeTypeRegistry* registry = Singleton<NONE_IS_LEAF>();
auto registration = std::make_unique<Registration>();
registration->kind = PyTreeKind::Custom;
registration->kind = PyTreeKind::CUSTOM;
registration->type = py::reinterpret_borrow<py::object>(cls);
registration->to_iterable = py::reinterpret_borrow<py::function>(to_iterable);
registration->from_iterable = py::reinterpret_borrow<py::function>(from_iterable);
Expand Down
74 changes: 37 additions & 37 deletions src/treespec/flatten.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,34 +48,34 @@ bool PyTreeSpec::FlattenIntoImpl(const py::handle& handle,
child, leaves, depth + 1, leaf_predicate, registry_namespace);
};
switch (node.kind) {
case PyTreeKind::None:
case PyTreeKind::NONE:
if (!NoneIsLeaf) break;
case PyTreeKind::Leaf:
case PyTreeKind::LEAF:
leaves.emplace_back(py::reinterpret_borrow<py::object>(handle));
break;

case PyTreeKind::Tuple: {
case PyTreeKind::TUPLE: {
node.arity = GET_SIZE<py::tuple>(handle);
for (ssize_t i = 0; i < node.arity; ++i) {
recurse(GET_ITEM_HANDLE<py::tuple>(handle, i));
}
break;
}

case PyTreeKind::List: {
case PyTreeKind::LIST: {
node.arity = GET_SIZE<py::list>(handle);
for (ssize_t i = 0; i < node.arity; ++i) {
recurse(GET_ITEM_HANDLE<py::list>(handle, i));
}
break;
}

case PyTreeKind::Dict:
case PyTreeKind::OrderedDict:
case PyTreeKind::DefaultDict: {
case PyTreeKind::DICT:
case PyTreeKind::ORDERED_DICT:
case PyTreeKind::DEFAULT_DICT: {
py::dict dict = py::reinterpret_borrow<py::dict>(handle);
py::list keys;
if (node.kind == PyTreeKind::OrderedDict) [[unlikely]] {
if (node.kind == PyTreeKind::ORDERED_DICT) [[unlikely]] {
keys = DictKeys(dict);
} else [[likely]] { // NOLINT
keys = SortedDictKeys(dict);
Expand All @@ -84,7 +84,7 @@ bool PyTreeSpec::FlattenIntoImpl(const py::handle& handle,
recurse(dict[key]);
}
node.arity = GET_SIZE<py::dict>(dict);
if (node.kind == PyTreeKind::DefaultDict) [[unlikely]] {
if (node.kind == PyTreeKind::DEFAULT_DICT) [[unlikely]] {
node.node_data =
py::make_tuple(py::getattr(handle, "default_factory"), std::move(keys));
} else [[likely]] { // NOLINT
Expand All @@ -93,7 +93,7 @@ bool PyTreeSpec::FlattenIntoImpl(const py::handle& handle,
break;
}

case PyTreeKind::NamedTuple: {
case PyTreeKind::NAMED_TUPLE: {
py::tuple tuple = py::reinterpret_borrow<py::tuple>(handle);
node.arity = GET_SIZE<py::tuple>(tuple);
node.node_data = py::reinterpret_borrow<py::object>(tuple.get_type());
Expand All @@ -103,7 +103,7 @@ bool PyTreeSpec::FlattenIntoImpl(const py::handle& handle,
break;
}

case PyTreeKind::Deque: {
case PyTreeKind::DEQUE: {
py::list list = handle.cast<py::list>();
node.arity = GET_SIZE<py::list>(list);
node.node_data = py::getattr(handle, "maxlen");
Expand All @@ -113,7 +113,7 @@ bool PyTreeSpec::FlattenIntoImpl(const py::handle& handle,
break;
}

case PyTreeKind::Custom: {
case PyTreeKind::CUSTOM: {
found_custom = true;
py::tuple out = py::cast<py::tuple>(node.custom->to_iterable(handle));
const size_t num_out = GET_SIZE<py::tuple>(out);
Expand Down Expand Up @@ -218,9 +218,9 @@ bool PyTreeSpec::FlattenIntoWithPathImpl(const py::handle& handle,
stack.pop_back();
};
switch (node.kind) {
case PyTreeKind::None:
case PyTreeKind::NONE:
if (!NoneIsLeaf) break;
case PyTreeKind::Leaf: {
case PyTreeKind::LEAF: {
py::tuple path = py::tuple{depth};
for (ssize_t d = 0; d < depth; ++d) {
SET_ITEM<py::tuple>(path, d, stack[d]);
Expand All @@ -230,28 +230,28 @@ bool PyTreeSpec::FlattenIntoWithPathImpl(const py::handle& handle,
break;
}

case PyTreeKind::Tuple: {
case PyTreeKind::TUPLE: {
node.arity = GET_SIZE<py::tuple>(handle);
for (ssize_t i = 0; i < node.arity; ++i) {
recurse(GET_ITEM_HANDLE<py::tuple>(handle, i), py::int_(i));
}
break;
}

case PyTreeKind::List: {
case PyTreeKind::LIST: {
node.arity = GET_SIZE<py::list>(handle);
for (ssize_t i = 0; i < node.arity; ++i) {
recurse(GET_ITEM_HANDLE<py::list>(handle, i), py::int_(i));
}
break;
}

case PyTreeKind::Dict:
case PyTreeKind::OrderedDict:
case PyTreeKind::DefaultDict: {
case PyTreeKind::DICT:
case PyTreeKind::ORDERED_DICT:
case PyTreeKind::DEFAULT_DICT: {
py::dict dict = py::reinterpret_borrow<py::dict>(handle);
py::list keys;
if (node.kind == PyTreeKind::OrderedDict) [[unlikely]] {
if (node.kind == PyTreeKind::ORDERED_DICT) [[unlikely]] {
keys = DictKeys(dict);
} else [[likely]] { // NOLINT
keys = SortedDictKeys(dict);
Expand All @@ -260,7 +260,7 @@ bool PyTreeSpec::FlattenIntoWithPathImpl(const py::handle& handle,
recurse(dict[key], key);
}
node.arity = GET_SIZE<py::dict>(dict);
if (node.kind == PyTreeKind::DefaultDict) [[unlikely]] {
if (node.kind == PyTreeKind::DEFAULT_DICT) [[unlikely]] {
node.node_data =
py::make_tuple(py::getattr(handle, "default_factory"), std::move(keys));
} else [[likely]] { // NOLINT
Expand All @@ -269,7 +269,7 @@ bool PyTreeSpec::FlattenIntoWithPathImpl(const py::handle& handle,
break;
}

case PyTreeKind::NamedTuple: {
case PyTreeKind::NAMED_TUPLE: {
py::tuple tuple = py::reinterpret_borrow<py::tuple>(handle);
node.arity = GET_SIZE<py::tuple>(tuple);
node.node_data = py::reinterpret_borrow<py::object>(tuple.get_type());
Expand All @@ -279,7 +279,7 @@ bool PyTreeSpec::FlattenIntoWithPathImpl(const py::handle& handle,
break;
}

case PyTreeKind::Deque: {
case PyTreeKind::DEQUE: {
py::list list = handle.cast<py::list>();
node.arity = GET_SIZE<py::list>(list);
node.node_data = py::getattr(handle, "maxlen");
Expand All @@ -289,7 +289,7 @@ bool PyTreeSpec::FlattenIntoWithPathImpl(const py::handle& handle,
break;
}

case PyTreeKind::Custom: {
case PyTreeKind::CUSTOM: {
found_custom = true;
py::tuple out = py::cast<py::tuple>(node.custom->to_iterable(handle));
const size_t num_out = GET_SIZE<py::tuple>(out);
Expand Down Expand Up @@ -398,18 +398,18 @@ py::list PyTreeSpec::FlattenUpToImpl(const py::handle& full_tree) const {
++it;

switch (node.kind) {
case PyTreeKind::None:
case PyTreeKind::NONE:
break;

case PyTreeKind::Leaf:
case PyTreeKind::LEAF:
if (leaf < 0) [[unlikely]] {
throw std::logic_error("Leaf count mismatch.");
}
leaves[leaf] = py::reinterpret_borrow<py::object>(object);
--leaf;
break;

case PyTreeKind::Tuple: {
case PyTreeKind::TUPLE: {
AssertExact<py::tuple>(object);
py::tuple tuple = py::reinterpret_borrow<py::tuple>(object);
if ((ssize_t)GET_SIZE<py::tuple>(tuple) != node.arity) [[unlikely]] {
Expand All @@ -425,7 +425,7 @@ py::list PyTreeSpec::FlattenUpToImpl(const py::handle& full_tree) const {
break;
}

case PyTreeKind::List: {
case PyTreeKind::LIST: {
AssertExact<py::list>(object);
py::list list = py::reinterpret_borrow<py::list>(object);
if ((ssize_t)GET_SIZE<py::list>(list) != node.arity) [[unlikely]] {
Expand All @@ -441,16 +441,16 @@ py::list PyTreeSpec::FlattenUpToImpl(const py::handle& full_tree) const {
break;
}

case PyTreeKind::Dict:
case PyTreeKind::OrderedDict: {
if (node.kind == PyTreeKind::OrderedDict) [[unlikely]] {
case PyTreeKind::DICT:
case PyTreeKind::ORDERED_DICT: {
if (node.kind == PyTreeKind::ORDERED_DICT) [[unlikely]] {
AssertExactOrderedDict(object);
} else [[likely]] { // NOLINT
AssertExact<py::dict>(object);
}
py::dict dict = py::reinterpret_borrow<py::dict>(object);
py::list keys;
if (node.kind == PyTreeKind::OrderedDict) [[unlikely]] {
if (node.kind == PyTreeKind::ORDERED_DICT) [[unlikely]] {
keys = DictKeys(dict);
} else [[likely]] { // NOLINT
keys = SortedDictKeys(dict);
Expand All @@ -467,7 +467,7 @@ py::list PyTreeSpec::FlattenUpToImpl(const py::handle& full_tree) const {
break;
}

case PyTreeKind::NamedTuple: {
case PyTreeKind::NAMED_TUPLE: {
AssertExactNamedTuple(object);
py::tuple tuple = py::reinterpret_borrow<py::tuple>(object);
if ((ssize_t)GET_SIZE<py::tuple>(tuple) != node.arity) [[unlikely]] {
Expand All @@ -489,7 +489,7 @@ py::list PyTreeSpec::FlattenUpToImpl(const py::handle& full_tree) const {
break;
}

case PyTreeKind::DefaultDict: {
case PyTreeKind::DEFAULT_DICT: {
AssertExactDefaultDict(object);
py::dict dict = py::reinterpret_borrow<py::dict>(object);
py::list keys = SortedDictKeys(dict);
Expand All @@ -515,7 +515,7 @@ py::list PyTreeSpec::FlattenUpToImpl(const py::handle& full_tree) const {
break;
}

case PyTreeKind::Deque: {
case PyTreeKind::DEQUE: {
AssertExactDeque(object);
py::list list = py::cast<py::list>(object);
if ((ssize_t)GET_SIZE<py::list>(list) != node.arity) [[unlikely]] {
Expand All @@ -531,7 +531,7 @@ py::list PyTreeSpec::FlattenUpToImpl(const py::handle& full_tree) const {
break;
}

case PyTreeKind::Custom: {
case PyTreeKind::CUSTOM: {
const PyTreeTypeRegistry::Registration* registration;
if (m_none_is_leaf) [[unlikely]] {
registration =
Expand Down Expand Up @@ -595,7 +595,7 @@ template <bool NoneIsLeaf>
const std::string& registry_namespace) {
const PyTreeTypeRegistry::Registration* custom;
for (const py::handle& h : iterable) {
if (GetKind<NoneIsLeaf>(h, &custom, registry_namespace) != PyTreeKind::Leaf) [[unlikely]] {
if (GetKind<NoneIsLeaf>(h, &custom, registry_namespace) != PyTreeKind::LEAF) [[unlikely]] {
return false;
}
}
Expand Down
20 changes: 10 additions & 10 deletions src/treespec/traversal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ py::object PyTreeSpec::Walk(const py::function& f_node,
const bool f_leaf_identity = f_leaf.is_none();
for (const Node& node : m_traversal) {
switch (node.kind) {
case PyTreeKind::Leaf: {
case PyTreeKind::LEAF: {
if (it == leaves.end()) [[unlikely]] {
throw std::invalid_argument("Too few leaves for PyTreeSpec");
}
Expand All @@ -42,15 +42,15 @@ py::object PyTreeSpec::Walk(const py::function& f_node,
break;
}

case PyTreeKind::None:
case PyTreeKind::Tuple:
case PyTreeKind::NamedTuple:
case PyTreeKind::List:
case PyTreeKind::Dict:
case PyTreeKind::OrderedDict:
case PyTreeKind::DefaultDict:
case PyTreeKind::Deque:
case PyTreeKind::Custom: {
case PyTreeKind::NONE:
case PyTreeKind::TUPLE:
case PyTreeKind::NAMED_TUPLE:
case PyTreeKind::LIST:
case PyTreeKind::DICT:
case PyTreeKind::ORDERED_DICT:
case PyTreeKind::DEFAULT_DICT:
case PyTreeKind::DEQUE:
case PyTreeKind::CUSTOM: {
if ((ssize_t)agenda.size() < node.arity) [[unlikely]] {
throw std::logic_error("Too few elements for custom type.");
}
Expand Down
Loading

0 comments on commit c70cd6a

Please sign in to comment.