Skip to content

Commit 05e1c6e

Browse files
nicolaspiwang-xianghao
authored andcommitted
Add tree.flatten_with_path and tree.assert_same_paths methods. (keras-team#20431)
* Add `tree.flatten_with_path` and `tree.assert_same_paths` methods. * Add methods in `__init__.py` * Fix api generated files.
1 parent 9608ccd commit 05e1c6e

File tree

7 files changed

+168
-1
lines changed

7 files changed

+168
-1
lines changed

keras/api/_tf_keras/keras/tree/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@
44
since your modifications would be overwritten.
55
"""
66

7+
from keras.src.tree.tree_api import assert_same_paths
78
from keras.src.tree.tree_api import assert_same_structure
89
from keras.src.tree.tree_api import flatten
10+
from keras.src.tree.tree_api import flatten_with_path
911
from keras.src.tree.tree_api import is_nested
1012
from keras.src.tree.tree_api import lists_to_tuples
1113
from keras.src.tree.tree_api import map_shape_structure

keras/api/tree/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@
44
since your modifications would be overwritten.
55
"""
66

7+
from keras.src.tree.tree_api import assert_same_paths
78
from keras.src.tree.tree_api import assert_same_structure
89
from keras.src.tree.tree_api import flatten
10+
from keras.src.tree.tree_api import flatten_with_path
911
from keras.src.tree.tree_api import is_nested
1012
from keras.src.tree.tree_api import lists_to_tuples
1113
from keras.src.tree.tree_api import map_shape_structure

keras/src/tree/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
from keras.src.tree.tree_api import assert_same_paths
12
from keras.src.tree.tree_api import assert_same_structure
23
from keras.src.tree.tree_api import flatten
4+
from keras.src.tree.tree_api import flatten_with_path
35
from keras.src.tree.tree_api import is_nested
46
from keras.src.tree.tree_api import lists_to_tuples
57
from keras.src.tree.tree_api import map_shape_structure

keras/src/tree/dmtree_impl.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@ def flatten(structure):
1717
return dmtree.flatten(structure)
1818

1919

20+
def flatten_with_path(structure):
21+
return dmtree.flatten_with_path(structure)
22+
23+
2024
def map_structure(func, *structures):
2125
return dmtree.map_structure(func, *structures)
2226

@@ -29,6 +33,21 @@ def assert_same_structure(a, b, check_types=True):
2933
return dmtree.assert_same_structure(a, b, check_types=check_types)
3034

3135

36+
def assert_same_paths(a, b):
37+
a_paths = set([path for path, leaf in dmtree.flatten_with_path(a)])
38+
b_paths = set([path for path, leaf in dmtree.flatten_with_path(b)])
39+
40+
if a_paths != b_paths:
41+
msg = "`a` and `b` don't have the same paths."
42+
a_diff = a_paths.difference(b_paths)
43+
if a_diff:
44+
msg += f"\nPaths in `a` missing in `b`:\n{a_diff}"
45+
b_diff = b_paths.difference(a_paths)
46+
if b_diff:
47+
msg += f"\nPaths in `b` missing in `a`:\n{b_diff}"
48+
raise ValueError(msg)
49+
50+
3251
def pack_sequence_as(structure, flat_sequence, sequence_fn=None):
3352
is_nested_fn = dmtree.is_nested
3453
sequence_fn = sequence_fn or dmtree._sequence_like

keras/src/tree/optree_impl.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,14 @@ def flatten(structure):
8080
return leaves
8181

8282

83+
def flatten_with_path(structure):
84+
paths, leaves, _ = optree.tree_flatten_with_path(
85+
structure, none_is_leaf=True, namespace="keras"
86+
)
87+
leaves_with_path = list(zip(paths, leaves))
88+
return leaves_with_path
89+
90+
8391
def map_structure(func, *structures):
8492
if not callable(func):
8593
raise TypeError(f"`func` must be callable. Received: func={func}")
@@ -125,6 +133,21 @@ def assert_same_structure(a, b, check_types=True):
125133
)
126134

127135

136+
def assert_same_paths(a, b):
137+
a_paths = set(optree.tree_paths(a, none_is_leaf=True, namespace="keras"))
138+
b_paths = set(optree.tree_paths(b, none_is_leaf=True, namespace="keras"))
139+
140+
if a_paths != b_paths:
141+
msg = "`a` and `b` don't have the same paths."
142+
a_diff = a_paths.difference(b_paths)
143+
if a_diff:
144+
msg += f"\nPaths in `a` missing in `b`:\n{a_diff}"
145+
b_diff = b_paths.difference(a_paths)
146+
if b_diff:
147+
msg += f"\nPaths in `b` missing in `a`:\n{b_diff}"
148+
raise ValueError(msg)
149+
150+
128151
def pack_sequence_as(structure, flat_sequence, sequence_fn=None):
129152
sequence_fn = sequence_fn or _sequence_like
130153

keras/src/tree/tree_api.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,32 @@ def flatten(structure):
121121
return tree_impl.flatten(structure)
122122

123123

124+
@keras_export("keras.tree.flatten_with_path")
125+
def flatten_with_path(structure):
126+
"""Flattens a possibly nested structure into a list.
127+
128+
This is a variant of flattens() which produces a
129+
list of pairs: `(path, item)`. A path is a tuple of indices and/or keys
130+
which uniquely identifies the position of the corresponding item.
131+
132+
Dictionaries with non-sortable keys cannot be flattened.
133+
134+
Examples:
135+
136+
>>> keras.flatten_with_path([{"foo": 42}])
137+
[((0, 'foo'), 42)]
138+
139+
140+
Args:
141+
structure: An arbitrarily nested structure.
142+
143+
Returns:
144+
A list of `(path, item)` pairs corresponding to the flattened
145+
version of the input `structure`.
146+
"""
147+
return tree_impl.flatten_with_path(structure)
148+
149+
124150
@keras_export("keras.tree.map_structure")
125151
def map_structure(func, *structures):
126152
"""Maps `func` through given structures.
@@ -205,6 +231,32 @@ def assert_same_structure(a, b, check_types=True):
205231
return tree_impl.assert_same_structure(a, b, check_types=check_types)
206232

207233

234+
@keras_export("keras.tree.assert_same_paths")
235+
def assert_same_paths(a, b):
236+
"""Asserts that two structures have identical paths in their tree structure.
237+
238+
This function verifies that two nested structures have the same paths.
239+
Unlike `assert_same_structure`, this function only checks the paths
240+
and ignores the nodes' types.
241+
242+
Examples:
243+
>>> keras.tree.assert_same_paths([0, 1], (2, 3))
244+
>>> Point1 = collections.namedtuple('Point1', ['x', 'y'])
245+
>>> Point2 = collections.namedtuple('Point2', ['x', 'y'])
246+
>>> keras.tree.assert_same_paths(Point1(0, 1), Point2(2, 3))
247+
248+
Args:
249+
a: an arbitrarily nested structure.
250+
b: an arbitrarily nested structure.
251+
252+
Raises:
253+
`ValueError`: If the paths in structure `a` don't match the paths
254+
in structure `b`. The error message will include the specific
255+
paths that differ.
256+
"""
257+
return tree_impl.assert_same_paths(a, b)
258+
259+
208260
@keras_export("keras.tree.pack_sequence_as")
209261
def pack_sequence_as(structure, flat_sequence, sequence_fn=None):
210262
"""Returns a given flattened sequence packed into a given structure.

keras/src/tree/tree_test.py

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,6 @@ def test_is_nested(self, tree_impl, is_optree):
5555

5656
def test_flatten(self, tree_impl, is_optree):
5757
structure = ((3, 4), 5, (6, 7, (9, 10), 8))
58-
flat = ["a", "b", "c", "d", "e", "f", "g", "h"]
5958

6059
self.assertEqual(
6160
tree_impl.flatten(structure), [3, 4, 5, 6, 7, 9, 10, 8]
@@ -68,6 +67,48 @@ def test_flatten(self, tree_impl, is_optree):
6867
self.assertEqual([5], tree_impl.flatten(5))
6968
self.assertEqual([np.array([5])], tree_impl.flatten(np.array([5])))
7069

70+
def test_flatten_with_path(self, tree_impl, is_optree):
71+
structure = {"b": (0, 1), "a": [2, 3]}
72+
flat_with_path = tree_impl.flatten_with_path(structure)
73+
74+
self.assertEqual(
75+
tree_impl.flatten(flat_with_path),
76+
tree_impl.flatten(
77+
[(("a", 0), 2), (("a", 1), 3), (("b", 0), 0), (("b", 1), 1)]
78+
),
79+
)
80+
point = collections.namedtuple("Point", ["x", "y", "z"])
81+
structure = point(x=(0, 1), y=[2, 3], z={"a": 4})
82+
flat_with_path = tree_impl.flatten_with_path(structure)
83+
84+
if is_optree:
85+
# optree doesn't return namedtuple's field name, but the index
86+
self.assertEqual(
87+
tree_impl.flatten(flat_with_path),
88+
tree_impl.flatten(
89+
[
90+
((0, 0), 0),
91+
((0, 1), 1),
92+
((1, 0), 2),
93+
((1, 1), 3),
94+
((2, "a"), 4),
95+
]
96+
),
97+
)
98+
else:
99+
self.assertEqual(
100+
tree_impl.flatten(flat_with_path),
101+
tree_impl.flatten(
102+
[
103+
(("x", 0), 0),
104+
(("x", 1), 1),
105+
(("y", 0), 2),
106+
(("y", 1), 3),
107+
(("z", "a"), 4),
108+
]
109+
),
110+
)
111+
71112
def test_flatten_dict_order(self, tree_impl, is_optree):
72113
ordered = collections.OrderedDict(
73114
[("d", 3), ("b", 1), ("a", 0), ("c", 2)]
@@ -225,6 +266,32 @@ def test_assert_same_structure(self, tree_impl, is_optree):
225266
STRUCTURE1, structure1_list, check_types=False
226267
)
227268

269+
def test_assert_same_paths(self, tree_impl, is_optree):
270+
assertion_message = "don't have the same paths"
271+
272+
tree_impl.assert_same_paths([0, 1], (0, 1))
273+
Point1 = collections.namedtuple("Point1", ["x", "y"])
274+
Point2 = collections.namedtuple("Point2", ["x", "y"])
275+
tree_impl.assert_same_paths(Point1(0, 1), Point2(0, 1))
276+
277+
with self.assertRaisesRegex(ValueError, assertion_message):
278+
tree_impl.assert_same_paths(
279+
STRUCTURE1, STRUCTURE_DIFFERENT_NUM_ELEMENTS
280+
)
281+
with self.assertRaisesRegex(ValueError, assertion_message):
282+
tree_impl.assert_same_paths([0, 1], np.array([0, 1]))
283+
with self.assertRaisesRegex(ValueError, assertion_message):
284+
tree_impl.assert_same_paths(0, [0, 1])
285+
with self.assertRaisesRegex(ValueError, assertion_message):
286+
tree_impl.assert_same_paths(STRUCTURE1, STRUCTURE_DIFFERENT_NESTING)
287+
with self.assertRaisesRegex(ValueError, assertion_message):
288+
tree_impl.assert_same_paths([[3], 4], [3, [4]])
289+
with self.assertRaisesRegex(ValueError, assertion_message):
290+
tree_impl.assert_same_paths({"a": 1}, {"b": 1})
291+
structure1_list = [[[1, 2], 3], 4, [5, 6]]
292+
tree_impl.assert_same_paths(STRUCTURE1, structure1_list)
293+
tree_impl.assert_same_paths(STRUCTURE1, STRUCTURE2)
294+
228295
def test_pack_sequence_as(self, tree_impl, is_optree):
229296
structure = {"key3": "", "key1": "", "key2": ""}
230297
flat_sequence = ["value1", "value2", "value3"]

0 commit comments

Comments
 (0)