Skip to content

Commit

Permalink
style: reuse temporary variables
Browse files Browse the repository at this point in the history
  • Loading branch information
XuehaiPan committed Jun 22, 2024
1 parent 165e83c commit 3b33775
Showing 1 changed file with 15 additions and 18 deletions.
33 changes: 15 additions & 18 deletions src/optree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,8 @@ void BuildModule(py::module_& mod) { // NOLINT[runtime/references]
.value("DEFAULTDICT", PyTreeKind::DefaultDict, "A collections.defaultdict.")
.value("DEQUE", PyTreeKind::Deque, "A collections.deque.")
.value("STRUCTSEQUENCE", PyTreeKind::StructSequence, "A PyStructSequence.");
reinterpret_cast<PyTypeObject*>(PyTreeKindTypeObject.ptr())->tp_name = "optree.PyTreeKind";
auto* PyTreeKind_Type = reinterpret_cast<PyTypeObject*>(PyTreeKindTypeObject.ptr());
PyTreeKind_Type->tp_name = "optree.PyTreeKind";
py::setattr(PyTreeKindTypeObject.ptr(), Py_Get_ID(__module__), Py_Get_ID(optree));

auto PyTreeSpecTypeObject = py::class_<PyTreeSpec>(
Expand Down Expand Up @@ -189,7 +190,8 @@ void BuildModule(py::module_& mod) { // NOLINT[runtime/references]
}),
// NOLINTEND[readability-function-cognitive-complexity,cppcoreguidelines-avoid-do-while]
py::module_local());
reinterpret_cast<PyTypeObject*>(PyTreeSpecTypeObject.ptr())->tp_name = "optree.PyTreeSpec";
auto* PyTreeSpec_Type = reinterpret_cast<PyTypeObject*>(PyTreeSpecTypeObject.ptr());
PyTreeSpec_Type->tp_name = "optree.PyTreeSpec";
py::setattr(PyTreeSpecTypeObject.ptr(), Py_Get_ID(__module__), Py_Get_ID(optree));

PyTreeSpecTypeObject
Expand Down Expand Up @@ -339,7 +341,8 @@ void BuildModule(py::module_& mod) { // NOLINT[runtime/references]
}),
// NOLINTEND[readability-function-cognitive-complexity,cppcoreguidelines-avoid-do-while]
py::module_local());
reinterpret_cast<PyTypeObject*>(PyTreeIterTypeObject.ptr())->tp_name = "optree.PyTreeIter";
auto* PyTreeIter_Type = reinterpret_cast<PyTypeObject*>(PyTreeIterTypeObject.ptr());
PyTreeIter_Type->tp_name = "optree.PyTreeIter";
py::setattr(PyTreeIterTypeObject.ptr(), Py_Get_ID(__module__), Py_Get_ID(optree));

PyTreeIterTypeObject
Expand All @@ -353,27 +356,21 @@ void BuildModule(py::module_& mod) { // NOLINT[runtime/references]
.def("__next__", &PyTreeIter::Next, "Return the next leaf in the pytree.");

#ifdef Py_TPFLAGS_IMMUTABLETYPE
reinterpret_cast<PyTypeObject*>(PyTreeKindTypeObject.ptr())->tp_flags |=
Py_TPFLAGS_IMMUTABLETYPE;
reinterpret_cast<PyTypeObject*>(PyTreeSpecTypeObject.ptr())->tp_flags |=
Py_TPFLAGS_IMMUTABLETYPE;
reinterpret_cast<PyTypeObject*>(PyTreeIterTypeObject.ptr())->tp_flags |=
Py_TPFLAGS_IMMUTABLETYPE;
reinterpret_cast<PyTypeObject*>(PyTreeKindTypeObject.ptr())->tp_flags &= ~Py_TPFLAGS_READY;
reinterpret_cast<PyTypeObject*>(PyTreeSpecTypeObject.ptr())->tp_flags &= ~Py_TPFLAGS_READY;
reinterpret_cast<PyTypeObject*>(PyTreeIterTypeObject.ptr())->tp_flags &= ~Py_TPFLAGS_READY;
PyTreeKind_Type->tp_flags |= Py_TPFLAGS_IMMUTABLETYPE;
PyTreeSpec_Type->tp_flags |= Py_TPFLAGS_IMMUTABLETYPE;
PyTreeIter_Type->tp_flags |= Py_TPFLAGS_IMMUTABLETYPE;
PyTreeKind_Type->tp_flags &= ~Py_TPFLAGS_READY;
PyTreeSpec_Type->tp_flags &= ~Py_TPFLAGS_READY;
PyTreeIter_Type->tp_flags &= ~Py_TPFLAGS_READY;
#endif

if (PyType_Ready(reinterpret_cast<PyTypeObject*>(PyTreeKindTypeObject.ptr())) < 0)
[[unlikely]] {
if (PyType_Ready(PyTreeKind_Type) < 0) [[unlikely]] {
INTERNAL_ERROR("`PyType_Ready(&PyTreeKind_Type)` failed.");
}
if (PyType_Ready(reinterpret_cast<PyTypeObject*>(PyTreeSpecTypeObject.ptr())) < 0)
[[unlikely]] {
if (PyType_Ready(PyTreeSpec_Type) < 0) [[unlikely]] {
INTERNAL_ERROR("`PyType_Ready(&PyTreeSpec_Type)` failed.");
}
if (PyType_Ready(reinterpret_cast<PyTypeObject*>(PyTreeIterTypeObject.ptr())) < 0)
[[unlikely]] {
if (PyType_Ready(PyTreeIter_Type) < 0) [[unlikely]] {
INTERNAL_ERROR("`PyType_Ready(&PyTreeIter_Type)` failed.");
}
}
Expand Down

0 comments on commit 3b33775

Please sign in to comment.