Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(ops): add tree_iter function #130

Merged
merged 2 commits into from
Mar 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Add `tree_iter` function by [@XuehaiPan](https://github.com/XuehaiPan) in [#130](https://github.com/metaopt/optree/pull/130).
- Add API to unregister node type in the registry by [@XuehaiPan](https://github.com/XuehaiPan) in [#124](https://github.com/metaopt/optree/pull/124).
- Add tree map functions with transposed outputs `tree_transpose_map` and `tree_transpose_map_with_path` by [@XuehaiPan](https://github.com/XuehaiPan) in [#127](https://github.com/metaopt/optree/pull/127).
- Add static constructors to create `PyTreeSpec` instances by [@XuehaiPan](https://github.com/XuehaiPan) in [#120](https://github.com/metaopt/optree/pull/120).
Expand Down
2 changes: 2 additions & 0 deletions docs/source/ops.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ Tree Manipulation Functions
tree_flatten
tree_flatten_with_path
tree_unflatten
tree_iter
tree_leaves
tree_structure
tree_paths
Expand All @@ -51,6 +52,7 @@ Tree Manipulation Functions
.. autofunction:: tree_flatten
.. autofunction:: tree_flatten_with_path
.. autofunction:: tree_unflatten
.. autofunction:: tree_iter
.. autofunction:: tree_leaves
.. autofunction:: tree_structure
.. autofunction:: tree_paths
Expand Down
6 changes: 6 additions & 0 deletions include/registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,12 @@ class PyTreeTypeRegistry {
template <bool NoneIsLeaf>
static RegistrationPtr Lookup(const py::object &cls, const std::string &registry_namespace);

// Compute the node kind of a given Python object.
template <bool NoneIsLeaf>
static PyTreeKind GetKind(const py::handle &handle,
RegistrationPtr &custom, // NOLINT[runtime/references]
const std::string &registry_namespace);

private:
template <bool NoneIsLeaf>
static PyTreeTypeRegistry *Singleton();
Expand Down
89 changes: 60 additions & 29 deletions include/treespec.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ limitations under the License.
#include <thread> // std::thread::id // NOLINT[build/c++11]
#include <tuple> // std::tuple
#include <unordered_set> // std::unordered_set
#include <utility> // std::pair
#include <utility> // std::pair, std::make_pair
#include <vector> // std::vector

#include "include/registry.h"
Expand All @@ -40,6 +40,28 @@ using ssize_t = py::ssize_t;
// The maximum depth of a pytree.
constexpr ssize_t MAX_RECURSION_DEPTH = 2000;

// Test whether the given object is a leaf node.
bool IsLeaf(const py::object &object,
const std::optional<py::function> &leaf_predicate,
const bool &none_is_leaf = false,
const std::string &registry_namespace = "");

// Test whether all elements in the given iterable are all leaves.
bool AllLeaves(const py::iterable &iterable,
const std::optional<py::function> &leaf_predicate,
const bool &none_is_leaf = false,
const std::string &registry_namespace = "");

template <bool NoneIsLeaf>
bool IsLeafImpl(const py::handle &handle,
const std::optional<py::function> &leaf_predicate,
const std::string &registry_namespace);

template <bool NoneIsLeaf>
bool AllLeavesImpl(const py::iterable &iterable,
const std::optional<py::function> &leaf_predicate,
const std::string &registry_namespace);

// A PyTreeSpec describes the tree structure of a PyTree. A PyTree is a tree of Python values, where
// the interior nodes are tuples, lists, dictionaries, or user-defined containers, and the leaves
// are other objects.
Expand Down Expand Up @@ -164,18 +186,6 @@ class PyTreeSpec {
const bool &none_is_leaf = false,
const std::string &registry_namespace = "");

// Test whether the given object is a leaf node.
static bool ObjectIsLeaf(const py::object &object,
const std::optional<py::function> &leaf_predicate,
const bool &none_is_leaf = false,
const std::string &registry_namespace = "");

// Test whether all elements in the given iterable are all leaves.
static bool AllLeaves(const py::iterable &iterable,
const std::optional<py::function> &leaf_predicate,
const bool &none_is_leaf = false,
const std::string &registry_namespace = "");

private:
using RegistrationPtr = PyTreeTypeRegistry::RegistrationPtr;

Expand Down Expand Up @@ -232,12 +242,6 @@ class PyTreeSpec {
const py::object *children,
const size_t &num_children);

// Compute the node kind of a given Python object.
template <bool NoneIsLeaf>
static PyTreeKind GetKind(const py::handle &handle,
RegistrationPtr &custom, // NOLINT[runtime/references]
const std::string &registry_namespace);

// Recursive helper used to implement Flatten().
bool FlattenInto(const py::handle &handle,
std::vector<py::object> &leaves, // NOLINT[runtime/references]
Expand Down Expand Up @@ -296,16 +300,6 @@ class PyTreeSpec {
static std::unique_ptr<PyTreeSpec> MakeFromCollectionImpl(const py::handle &handle,
std::string registry_namespace);

template <bool NoneIsLeaf>
static bool ObjectIsLeafImpl(const py::handle &handle,
const std::optional<py::function> &leaf_predicate,
const std::string &registry_namespace);

template <bool NoneIsLeaf>
static bool AllLeavesImpl(const py::iterable &iterable,
const std::optional<py::function> &leaf_predicate,
const std::string &registry_namespace);

class ThreadIndentTypeHash {
public:
using is_transparent = void;
Expand All @@ -323,4 +317,41 @@ class PyTreeSpec {
sm_hash_running{};
};

class PyTreeIter {
public:
PyTreeIter(const py::object &tree,
const std::optional<py::function> &leaf_predicate,
bool none_is_leaf,
std::string registry_namespace)
: m_agenda({std::make_pair(tree, 0)}),
m_leaf_predicate(leaf_predicate),
m_none_is_leaf(none_is_leaf),
m_namespace(std::move(registry_namespace)){};

PyTreeIter() = delete;

~PyTreeIter() = default;

PyTreeIter(const PyTreeIter &) = delete;

PyTreeIter operator=(const PyTreeIter &) = delete;

PyTreeIter(PyTreeIter &&) = default;

PyTreeIter &operator=(PyTreeIter &&) = default;

[[nodiscard]] PyTreeIter &Iter() { return *this; }

[[nodiscard]] py::object Next();

private:
std::vector<std::pair<py::object, ssize_t>> m_agenda;
std::optional<py::function> m_leaf_predicate;
bool m_none_is_leaf;
std::string m_namespace;

template <bool NoneIsLeaf>
[[nodiscard]] py::object NextImpl();
};

} // namespace optree
13 changes: 12 additions & 1 deletion optree/_C.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import builtins
import enum
from collections.abc import Callable, Iterable
from collections.abc import Callable, Iterable, Iterator
from typing import Any

from optree.typing import CustomTreeNode, FlattenFunc, MetaData, PyTree, T, U, UnflattenFunc
Expand Down Expand Up @@ -123,6 +123,17 @@ class PyTreeSpec:
def __hash__(self) -> int: ...
def __len__(self) -> int: ...

class PyTreeIter(Iterator[T]):
def __init__(
self,
tree: PyTree[T],
leaf_predicate: Callable[[T], bool] | None = None,
node_is_leaf: bool = False,
namespace: str = '',
) -> None: ...
def __iter__(self) -> PyTreeIter[T]: ...
def __next__(self) -> T: ...

def register_node(
cls: type[CustomTreeNode[T]],
flatten_func: FlattenFunc,
Expand Down
2 changes: 2 additions & 0 deletions optree/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
tree_flatten_one_level,
tree_flatten_with_path,
tree_is_leaf,
tree_iter,
tree_leaves,
tree_map,
tree_map_,
Expand Down Expand Up @@ -106,6 +107,7 @@
'tree_flatten',
'tree_flatten_with_path',
'tree_unflatten',
'tree_iter',
'tree_leaves',
'tree_structure',
'tree_paths',
Expand Down
50 changes: 46 additions & 4 deletions optree/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
'tree_flatten',
'tree_flatten_with_path',
'tree_unflatten',
'tree_iter',
'tree_leaves',
'tree_structure',
'tree_paths',
Expand Down Expand Up @@ -129,7 +130,7 @@ def tree_flatten(
) -> tuple[list[T], PyTreeSpec]:
"""Flatten a pytree.

See also :func:`tree_flatten_with_path`.
See also :func:`tree_flatten_with_path` and :func:`tree_unflatten`.

The flattening order (i.e., the order of elements in the output list) is deterministic,
corresponding to a left-to-right depth-first tree traversal.
Expand Down Expand Up @@ -283,6 +284,47 @@ def tree_unflatten(treespec: PyTreeSpec, leaves: Iterable[T]) -> PyTree[T]:
return treespec.unflatten(leaves)


def tree_iter(
tree: PyTree[T],
is_leaf: Callable[[T], bool] | None = None,
*,
none_is_leaf: bool = False,
namespace: str = '',
) -> Iterable[T]:
"""Get an iterator over the leaves of a pytree.

See also :func:`tree_flatten` and :func:`tree_leaves`.

>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
>>> list(tree_iter(tree))
[1, 2, 3, 4, 5]
>>> list(tree_iter(tree, none_is_leaf=True))
[1, 2, 3, 4, None, 5]
>>> list(tree_iter(1))
[1]
>>> list(tree_iter(None))
[]
>>> list(tree_iter(None, none_is_leaf=True))
[None]

Args:
tree (pytree): A pytree to iterate over.
is_leaf (callable, optional): An optionally specified function that will be called at each
flattening step. It should return a boolean, with :data:`True` stopping the traversal
and the whole subtree being treated as a leaf, and :data:`False` indicating the
flattening should traverse the current object.
none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
:data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
treespec rather than in the leaves list. (default: :data:`False`)
namespace (str, optional): The registry namespace used for custom pytree node types.
(default: :const:`''`, i.e., the global namespace)

Returns:
An iterator over the leaf values.
"""
return _C.PyTreeIter(tree, is_leaf, none_is_leaf, namespace)


def tree_leaves(
tree: PyTree[T],
is_leaf: Callable[[T], bool] | None = None,
Expand All @@ -292,7 +334,7 @@ def tree_leaves(
) -> list[T]:
"""Get the leaves of a pytree.

See also :func:`tree_flatten`.
See also :func:`tree_flatten` and :func:`tree_iter`.

>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
>>> tree_leaves(tree)
Expand Down Expand Up @@ -1827,7 +1869,7 @@ def tree_all(
Otherwise, :data:`False`.
"""
return all(
tree_leaves(
tree_iter(
tree, # type: ignore[arg-type]
is_leaf=is_leaf, # type: ignore[arg-type]
none_is_leaf=none_is_leaf,
Expand Down Expand Up @@ -1878,7 +1920,7 @@ def tree_any(
empty, return :data:`False`.
"""
return any(
tree_leaves(
tree_iter(
tree, # type: ignore[arg-type]
is_leaf=is_leaf, # type: ignore[arg-type]
none_is_leaf=none_is_leaf,
Expand Down
35 changes: 29 additions & 6 deletions src/optree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ limitations under the License.
#include <pybind11/stl.h>

#include <optional> // std::nullopt
#include <string> // std::string

#include "include/exceptions.h"
#include "include/registry.h"
Expand Down Expand Up @@ -67,14 +68,14 @@ void BuildModule(py::module_& mod) { // NOLINT[runtime/references]
py::arg("none_is_leaf") = false,
py::arg("namespace") = "")
.def("is_leaf",
&PyTreeSpec::ObjectIsLeaf,
&IsLeaf,
"Test whether the given object is a leaf node.",
py::arg("obj"),
py::arg("leaf_predicate") = std::nullopt,
py::arg("none_is_leaf") = false,
py::arg("namespace") = "")
.def("all_leaves",
&PyTreeSpec::AllLeaves,
&AllLeaves,
"Test whether all elements in the given iterable are all leaves.",
py::arg("iterable"),
py::arg("leaf_predicate") = std::nullopt,
Expand Down Expand Up @@ -259,22 +260,44 @@ void BuildModule(py::module_& mod) { // NOLINT[runtime/references]
"Serialization support for PyTreeSpec.",
py::arg("state"));

auto PyTreeIterTypeObject =
py::class_<PyTreeIter>(mod, "PyTreeIter", "Iterator over the leaves of a pytree.");
reinterpret_cast<PyTypeObject*>(PyTreeIterTypeObject.ptr())->tp_name = "optree.PyTreeIter";
py::setattr(PyTreeIterTypeObject.ptr(), Py_Get_ID(__module__), Py_Get_ID(optree));

PyTreeIterTypeObject
.def(py::init<py::object, std::optional<py::function>, bool, std::string>(),
"Create a new iterator over the leaves of a pytree.",
py::arg("tree"),
py::arg("leaf_predicate") = std::nullopt,
py::arg("none_is_leaf") = false,
py::arg("namespace") = "")
.def("__iter__", &PyTreeIter::Iter, "Return the iterator object itself.")
.def("__next__", &PyTreeIter::Next, "Return the next leaf in the pytree.");

#ifdef Py_TPFLAGS_IMMUTABLETYPE
reinterpret_cast<PyTypeObject*>(PyTreeKindTypeObject.ptr())->tp_flags |=
Py_TPFLAGS_IMMUTABLETYPE;
reinterpret_cast<PyTypeObject*>(PyTreeSpecTypeObject.ptr())->tp_flags |=
Py_TPFLAGS_IMMUTABLETYPE;
reinterpret_cast<PyTypeObject*>(PyTreeKindTypeObject.ptr())->tp_flags |=
reinterpret_cast<PyTypeObject*>(PyTreeIterTypeObject.ptr())->tp_flags |=
Py_TPFLAGS_IMMUTABLETYPE;
reinterpret_cast<PyTypeObject*>(PyTreeSpecTypeObject.ptr())->tp_flags &= ~Py_TPFLAGS_READY;
reinterpret_cast<PyTypeObject*>(PyTreeKindTypeObject.ptr())->tp_flags &= ~Py_TPFLAGS_READY;
reinterpret_cast<PyTypeObject*>(PyTreeSpecTypeObject.ptr())->tp_flags &= ~Py_TPFLAGS_READY;
reinterpret_cast<PyTypeObject*>(PyTreeIterTypeObject.ptr())->tp_flags &= ~Py_TPFLAGS_READY;
#endif

if (PyType_Ready(reinterpret_cast<PyTypeObject*>(PyTreeKindTypeObject.ptr())) < 0)
[[unlikely]] {
INTERNAL_ERROR("`PyType_Ready(&PyTreeKind_Type)` failed.");
}
if (PyType_Ready(reinterpret_cast<PyTypeObject*>(PyTreeSpecTypeObject.ptr())) < 0)
[[unlikely]] {
INTERNAL_ERROR("`PyType_Ready(&PyTreeSpec_Type)` failed.");
}
if (PyType_Ready(reinterpret_cast<PyTypeObject*>(PyTreeKindTypeObject.ptr())) < 0)
if (PyType_Ready(reinterpret_cast<PyTypeObject*>(PyTreeIterTypeObject.ptr())) < 0)
[[unlikely]] {
INTERNAL_ERROR("`PyType_Ready(&PyTreeKind_Type)` failed.");
INTERNAL_ERROR("`PyType_Ready(&PyTreeIter_Type)` failed.");
}
}

Expand Down
Loading
Loading