@@ -55,7 +55,6 @@ def test_is_nested(self, tree_impl, is_optree):
5555
5656    def  test_flatten (self , tree_impl , is_optree ):
5757        structure  =  ((3 , 4 ), 5 , (6 , 7 , (9 , 10 ), 8 ))
58-         flat  =  ["a" , "b" , "c" , "d" , "e" , "f" , "g" , "h" ]
5958
6059        self .assertEqual (
6160            tree_impl .flatten (structure ), [3 , 4 , 5 , 6 , 7 , 9 , 10 , 8 ]
@@ -68,6 +67,48 @@ def test_flatten(self, tree_impl, is_optree):
6867        self .assertEqual ([5 ], tree_impl .flatten (5 ))
6968        self .assertEqual ([np .array ([5 ])], tree_impl .flatten (np .array ([5 ])))
7069
70+     def  test_flatten_with_path (self , tree_impl , is_optree ):
71+         structure  =  {"b" : (0 , 1 ), "a" : [2 , 3 ]}
72+         flat_with_path  =  tree_impl .flatten_with_path (structure )
73+ 
74+         self .assertEqual (
75+             tree_impl .flatten (flat_with_path ),
76+             tree_impl .flatten (
77+                 [(("a" , 0 ), 2 ), (("a" , 1 ), 3 ), (("b" , 0 ), 0 ), (("b" , 1 ), 1 )]
78+             ),
79+         )
80+         point  =  collections .namedtuple ("Point" , ["x" , "y" , "z" ])
81+         structure  =  point (x = (0 , 1 ), y = [2 , 3 ], z = {"a" : 4 })
82+         flat_with_path  =  tree_impl .flatten_with_path (structure )
83+ 
84+         if  is_optree :
85+             # optree doesn't return namedtuple's field name, but the index 
86+             self .assertEqual (
87+                 tree_impl .flatten (flat_with_path ),
88+                 tree_impl .flatten (
89+                     [
90+                         ((0 , 0 ), 0 ),
91+                         ((0 , 1 ), 1 ),
92+                         ((1 , 0 ), 2 ),
93+                         ((1 , 1 ), 3 ),
94+                         ((2 , "a" ), 4 ),
95+                     ]
96+                 ),
97+             )
98+         else :
99+             self .assertEqual (
100+                 tree_impl .flatten (flat_with_path ),
101+                 tree_impl .flatten (
102+                     [
103+                         (("x" , 0 ), 0 ),
104+                         (("x" , 1 ), 1 ),
105+                         (("y" , 0 ), 2 ),
106+                         (("y" , 1 ), 3 ),
107+                         (("z" , "a" ), 4 ),
108+                     ]
109+                 ),
110+             )
111+ 
71112    def  test_flatten_dict_order (self , tree_impl , is_optree ):
72113        ordered  =  collections .OrderedDict (
73114            [("d" , 3 ), ("b" , 1 ), ("a" , 0 ), ("c" , 2 )]
@@ -225,6 +266,32 @@ def test_assert_same_structure(self, tree_impl, is_optree):
225266                STRUCTURE1 , structure1_list , check_types = False 
226267            )
227268
269+     def  test_assert_same_paths (self , tree_impl , is_optree ):
270+         assertion_message  =  "don't have the same paths" 
271+ 
272+         tree_impl .assert_same_paths ([0 , 1 ], (0 , 1 ))
273+         Point1  =  collections .namedtuple ("Point1" , ["x" , "y" ])
274+         Point2  =  collections .namedtuple ("Point2" , ["x" , "y" ])
275+         tree_impl .assert_same_paths (Point1 (0 , 1 ), Point2 (0 , 1 ))
276+ 
277+         with  self .assertRaisesRegex (ValueError , assertion_message ):
278+             tree_impl .assert_same_paths (
279+                 STRUCTURE1 , STRUCTURE_DIFFERENT_NUM_ELEMENTS 
280+             )
281+         with  self .assertRaisesRegex (ValueError , assertion_message ):
282+             tree_impl .assert_same_paths ([0 , 1 ], np .array ([0 , 1 ]))
283+         with  self .assertRaisesRegex (ValueError , assertion_message ):
284+             tree_impl .assert_same_paths (0 , [0 , 1 ])
285+         with  self .assertRaisesRegex (ValueError , assertion_message ):
286+             tree_impl .assert_same_paths (STRUCTURE1 , STRUCTURE_DIFFERENT_NESTING )
287+         with  self .assertRaisesRegex (ValueError , assertion_message ):
288+             tree_impl .assert_same_paths ([[3 ], 4 ], [3 , [4 ]])
289+         with  self .assertRaisesRegex (ValueError , assertion_message ):
290+             tree_impl .assert_same_paths ({"a" : 1 }, {"b" : 1 })
291+         structure1_list  =  [[[1 , 2 ], 3 ], 4 , [5 , 6 ]]
292+         tree_impl .assert_same_paths (STRUCTURE1 , structure1_list )
293+         tree_impl .assert_same_paths (STRUCTURE1 , STRUCTURE2 )
294+ 
228295    def  test_pack_sequence_as (self , tree_impl , is_optree ):
229296        structure  =  {"key3" : "" , "key1" : "" , "key2" : "" }
230297        flat_sequence  =  ["value1" , "value2" , "value3" ]
0 commit comments