Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions keras/api/_tf_keras/keras/tree/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
since your modifications would be overwritten.
"""

from keras.src.tree.tree_api import assert_same_paths
from keras.src.tree.tree_api import assert_same_structure
from keras.src.tree.tree_api import flatten
from keras.src.tree.tree_api import flatten_with_path
from keras.src.tree.tree_api import is_nested
from keras.src.tree.tree_api import lists_to_tuples
from keras.src.tree.tree_api import map_shape_structure
Expand Down
2 changes: 2 additions & 0 deletions keras/api/tree/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
since your modifications would be overwritten.
"""

from keras.src.tree.tree_api import assert_same_paths
from keras.src.tree.tree_api import assert_same_structure
from keras.src.tree.tree_api import flatten
from keras.src.tree.tree_api import flatten_with_path
from keras.src.tree.tree_api import is_nested
from keras.src.tree.tree_api import lists_to_tuples
from keras.src.tree.tree_api import map_shape_structure
Expand Down
2 changes: 2 additions & 0 deletions keras/src/tree/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from keras.src.tree.tree_api import assert_same_paths
from keras.src.tree.tree_api import assert_same_structure
from keras.src.tree.tree_api import flatten
from keras.src.tree.tree_api import flatten_with_path
from keras.src.tree.tree_api import is_nested
from keras.src.tree.tree_api import lists_to_tuples
from keras.src.tree.tree_api import map_shape_structure
Expand Down
19 changes: 19 additions & 0 deletions keras/src/tree/dmtree_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ def flatten(structure):
return dmtree.flatten(structure)


def flatten_with_path(structure):
return dmtree.flatten_with_path(structure)


def map_structure(func, *structures):
return dmtree.map_structure(func, *structures)

Expand All @@ -29,6 +33,21 @@ def assert_same_structure(a, b, check_types=True):
return dmtree.assert_same_structure(a, b, check_types=check_types)


def assert_same_paths(a, b):
a_paths = set([path for path, leaf in dmtree.flatten_with_path(a)])
b_paths = set([path for path, leaf in dmtree.flatten_with_path(b)])

if a_paths != b_paths:
msg = "`a` and `b` don't have the same paths."
a_diff = a_paths.difference(b_paths)
if a_diff:
msg += f"\nPaths in `a` missing in `b`:\n{a_diff}"
b_diff = b_paths.difference(a_paths)
if b_diff:
msg += f"\nPaths in `b` missing in `a`:\n{b_diff}"
raise ValueError(msg)


def pack_sequence_as(structure, flat_sequence, sequence_fn=None):
is_nested_fn = dmtree.is_nested
sequence_fn = sequence_fn or dmtree._sequence_like
Expand Down
23 changes: 23 additions & 0 deletions keras/src/tree/optree_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,14 @@ def flatten(structure):
return leaves


def flatten_with_path(structure):
paths, leaves, _ = optree.tree_flatten_with_path(
structure, none_is_leaf=True, namespace="keras"
)
leaves_with_path = list(zip(paths, leaves))
return leaves_with_path


def map_structure(func, *structures):
if not callable(func):
raise TypeError(f"`func` must be callable. Received: func={func}")
Expand Down Expand Up @@ -125,6 +133,21 @@ def assert_same_structure(a, b, check_types=True):
)


def assert_same_paths(a, b):
a_paths = set(optree.tree_paths(a, none_is_leaf=True, namespace="keras"))
b_paths = set(optree.tree_paths(b, none_is_leaf=True, namespace="keras"))

if a_paths != b_paths:
msg = "`a` and `b` don't have the same paths."
a_diff = a_paths.difference(b_paths)
if a_diff:
msg += f"\nPaths in `a` missing in `b`:\n{a_diff}"
b_diff = b_paths.difference(a_paths)
if b_diff:
msg += f"\nPaths in `b` missing in `a`:\n{b_diff}"
raise ValueError(msg)


def pack_sequence_as(structure, flat_sequence, sequence_fn=None):
sequence_fn = sequence_fn or _sequence_like

Expand Down
52 changes: 52 additions & 0 deletions keras/src/tree/tree_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,32 @@ def flatten(structure):
return tree_impl.flatten(structure)


@keras_export("keras.tree.flatten_with_path")
def flatten_with_path(structure):
"""Flattens a possibly nested structure into a list.

This is a variant of flattens() which produces a
list of pairs: `(path, item)`. A path is a tuple of indices and/or keys
which uniquely identifies the position of the corresponding item.

Dictionaries with non-sortable keys cannot be flattened.

Examples:

>>> keras.flatten_with_path([{"foo": 42}])
[((0, 'foo'), 42)]


Args:
structure: An arbitrarily nested structure.

Returns:
A list of `(path, item)` pairs corresponding to the flattened
version of the input `structure`.
"""
return tree_impl.flatten_with_path(structure)


@keras_export("keras.tree.map_structure")
def map_structure(func, *structures):
"""Maps `func` through given structures.
Expand Down Expand Up @@ -205,6 +231,32 @@ def assert_same_structure(a, b, check_types=True):
return tree_impl.assert_same_structure(a, b, check_types=check_types)


@keras_export("keras.tree.assert_same_paths")
def assert_same_paths(a, b):
"""Asserts that two structures have identical paths in their tree structure.

This function verifies that two nested structures have the same paths.
Unlike `assert_same_structure`, this function only checks the paths
and ignores the nodes' types.

Examples:
>>> keras.tree.assert_same_paths([0, 1], (2, 3))
>>> Point1 = collections.namedtuple('Point1', ['x', 'y'])
>>> Point2 = collections.namedtuple('Point2', ['x', 'y'])
>>> keras.tree.assert_same_paths(Point1(0, 1), Point2(2, 3))

Args:
a: an arbitrarily nested structure.
b: an arbitrarily nested structure.

Raises:
`ValueError`: If the paths in structure `a` don't match the paths
in structure `b`. The error message will include the specific
paths that differ.
"""
return tree_impl.assert_same_paths(a, b)


@keras_export("keras.tree.pack_sequence_as")
def pack_sequence_as(structure, flat_sequence, sequence_fn=None):
"""Returns a given flattened sequence packed into a given structure.
Expand Down
69 changes: 68 additions & 1 deletion keras/src/tree/tree_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ def test_is_nested(self, tree_impl, is_optree):

def test_flatten(self, tree_impl, is_optree):
structure = ((3, 4), 5, (6, 7, (9, 10), 8))
flat = ["a", "b", "c", "d", "e", "f", "g", "h"]

self.assertEqual(
tree_impl.flatten(structure), [3, 4, 5, 6, 7, 9, 10, 8]
Expand All @@ -68,6 +67,48 @@ def test_flatten(self, tree_impl, is_optree):
self.assertEqual([5], tree_impl.flatten(5))
self.assertEqual([np.array([5])], tree_impl.flatten(np.array([5])))

def test_flatten_with_path(self, tree_impl, is_optree):
structure = {"b": (0, 1), "a": [2, 3]}
flat_with_path = tree_impl.flatten_with_path(structure)

self.assertEqual(
tree_impl.flatten(flat_with_path),
tree_impl.flatten(
[(("a", 0), 2), (("a", 1), 3), (("b", 0), 0), (("b", 1), 1)]
),
)
point = collections.namedtuple("Point", ["x", "y", "z"])
structure = point(x=(0, 1), y=[2, 3], z={"a": 4})
flat_with_path = tree_impl.flatten_with_path(structure)

if is_optree:
# optree doesn't return namedtuple's field name, but the index
self.assertEqual(
tree_impl.flatten(flat_with_path),
tree_impl.flatten(
[
((0, 0), 0),
((0, 1), 1),
((1, 0), 2),
((1, 1), 3),
((2, "a"), 4),
]
),
)
else:
self.assertEqual(
tree_impl.flatten(flat_with_path),
tree_impl.flatten(
[
(("x", 0), 0),
(("x", 1), 1),
(("y", 0), 2),
(("y", 1), 3),
(("z", "a"), 4),
]
),
)

def test_flatten_dict_order(self, tree_impl, is_optree):
ordered = collections.OrderedDict(
[("d", 3), ("b", 1), ("a", 0), ("c", 2)]
Expand Down Expand Up @@ -225,6 +266,32 @@ def test_assert_same_structure(self, tree_impl, is_optree):
STRUCTURE1, structure1_list, check_types=False
)

def test_assert_same_paths(self, tree_impl, is_optree):
assertion_message = "don't have the same paths"

tree_impl.assert_same_paths([0, 1], (0, 1))
Point1 = collections.namedtuple("Point1", ["x", "y"])
Point2 = collections.namedtuple("Point2", ["x", "y"])
tree_impl.assert_same_paths(Point1(0, 1), Point2(0, 1))

with self.assertRaisesRegex(ValueError, assertion_message):
tree_impl.assert_same_paths(
STRUCTURE1, STRUCTURE_DIFFERENT_NUM_ELEMENTS
)
with self.assertRaisesRegex(ValueError, assertion_message):
tree_impl.assert_same_paths([0, 1], np.array([0, 1]))
with self.assertRaisesRegex(ValueError, assertion_message):
tree_impl.assert_same_paths(0, [0, 1])
with self.assertRaisesRegex(ValueError, assertion_message):
tree_impl.assert_same_paths(STRUCTURE1, STRUCTURE_DIFFERENT_NESTING)
with self.assertRaisesRegex(ValueError, assertion_message):
tree_impl.assert_same_paths([[3], 4], [3, [4]])
with self.assertRaisesRegex(ValueError, assertion_message):
tree_impl.assert_same_paths({"a": 1}, {"b": 1})
structure1_list = [[[1, 2], 3], 4, [5, 6]]
tree_impl.assert_same_paths(STRUCTURE1, structure1_list)
tree_impl.assert_same_paths(STRUCTURE1, STRUCTURE2)

def test_pack_sequence_as(self, tree_impl, is_optree):
structure = {"key3": "", "key1": "", "key2": ""}
flat_sequence = ["value1", "value2", "value3"]
Expand Down