11from collections import deque
2- from typing import Any , Deque , Dict , List , Type , TypeVar , Union
2+ from typing import Any , Deque , Dict , List , Set , Type , TypeVar , Union
33
44from bigtree .node .basenode import BaseNode
55from bigtree .node .binarynode import BinaryNode
@@ -56,18 +56,21 @@ def recursive_add_child(
5656
5757def prune_tree (
5858 tree : Union [BinaryNodeT , NodeT ],
59- prune_path : str = "" ,
59+ prune_path : Union [List [str ], str ] = "" ,
60+ exact : bool = False ,
6061 sep : str = "/" ,
6162 max_depth : int = 0 ,
6263) -> Union [BinaryNodeT , NodeT ]:
6364 """Prune tree by path or depth, returns the root of a *copy* of the original tree.
6465
6566 For pruning by `prune_path`,
66- All siblings along the prune path will be removed.
67- Prune path name should be unique, can be full path, partial path (trailing part of path), or node name.
67+ - All siblings along the prune path will be removed.
68+ - If ``exact=True``, all descendants of prune path will be removed.
69+ - Prune path can be string (only one path) or a list of strings (multiple paths).
70+ - Prune path name should be unique, can be full path, partial path (trailing part of path), or node name.
6871
6972 For pruning by `max_depth`,
70- All nodes that are beyond `max_depth` will be removed.
73+ - All nodes that are beyond `max_depth` will be removed.
7174
7275 Path should contain ``Node`` name, separated by `sep`.
7376 - For example: Path string "a/b" refers to Node("b") with parent Node("a").
@@ -85,13 +88,33 @@ def prune_tree(
8588 │ └── d
8689 └── e
8790
91+ Prune (default is keep descendants)
92+
8893 >>> root_pruned = prune_tree(root, "a/b")
8994 >>> root_pruned.show()
9095 a
9196 └── b
9297 ├── c
9398 └── d
9499
100+ Prune exact path
101+
102+ >>> root_pruned = prune_tree(root, "a/b", exact=True)
103+ >>> root_pruned.show()
104+ a
105+ └── b
106+
107+ Prune multiple paths
108+
109+ >>> root_pruned = prune_tree(root, ["a/b/d", "a/e"])
110+ >>> root_pruned.show()
111+ a
112+ ├── b
113+ │ └── d
114+ └── e
115+
116+ Prune by depth
117+
95118 >>> root_pruned = prune_tree(root, max_depth=2)
96119 >>> root_pruned.show()
97120 a
@@ -100,31 +123,47 @@ def prune_tree(
100123
101124 Args:
102125 tree (Union[BinaryNode, Node]): existing tree
103- prune_path (str): prune path, all siblings along the prune path will be removed
126+ prune_path (List[str] | str): prune path(s), all siblings along the prune path(s) will be removed
127+ exact (bool): prune path(s) to be exactly the path, defaults to False (descendants of the path are retained)
104128 sep (str): path separator of `prune_path`
105129 max_depth (int): maximum depth of pruned tree, based on `depth` attribute, defaults to None
106130
107131 Returns:
108132 (Union[BinaryNode, Node])
109133 """
110- if not prune_path and not max_depth :
134+ if isinstance (prune_path , str ):
135+ prune_path = [prune_path ] if prune_path else []
136+
137+ if not len (prune_path ) and not max_depth :
111138 raise ValueError ("Please specify either `prune_path` or `max_depth` or both." )
112139
113140 tree_copy = tree .copy ()
114141
115142 # Prune by path (prune bottom-up)
116- if prune_path :
117- prune_path = prune_path .replace (sep , tree .sep )
118- child = find_path (tree_copy , prune_path )
119- if not child :
120- raise NotFoundError (
121- f"Cannot find any node matching path_name ending with { prune_path } "
122- )
123- while child .parent :
124- for other_children in child .parent .children :
125- if other_children != child :
126- other_children .parent = None
127- child = child .parent
143+ if len (prune_path ):
144+ ancestors_to_prune : Set [Union [BinaryNodeT , NodeT ]] = set ()
145+ nodes_to_prune : Set [Union [BinaryNodeT , NodeT ]] = set ()
146+ for path in prune_path :
147+ path = path .replace (sep , tree .sep )
148+ child = find_path (tree_copy , path )
149+ if not child :
150+ raise NotFoundError (
151+ f"Cannot find any node matching path_name ending with { path } "
152+ )
153+ nodes_to_prune .add (child )
154+ ancestors_to_prune .update (list (child .ancestors ))
155+
156+ if exact :
157+ ancestors_to_prune .update (nodes_to_prune )
158+
159+ for node in ancestors_to_prune :
160+ for child in node .children :
161+ if (
162+ child
163+ and child not in ancestors_to_prune
164+ and child not in nodes_to_prune
165+ ):
166+ child .parent = None
128167
129168 # Prune by depth (prune top-down)
130169 if max_depth :
0 commit comments