Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions bigtree/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
tree_to_dot,
tree_to_mermaid,
tree_to_nested_dict,
tree_to_newick,
tree_to_pillow,
yield_tree,
)
Expand Down
34 changes: 34 additions & 0 deletions bigtree/tree/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
"tree_to_dot",
"tree_to_pillow",
"tree_to_mermaid",
"tree_to_newick",
]

T = TypeVar("T", bound=Node)
Expand Down Expand Up @@ -1156,3 +1157,36 @@ def get_attr(
flows="\n".join(flows),
styles="\n".join(styles),
)


def tree_to_newick(tree: T) -> str:
"""Export tree to Newick notation.

>>> from bigtree import Node, tree_to_newick
>>> root = Node("a", age=90)
>>> b = Node("b", age=65, parent=root)
>>> c = Node("c", age=60, parent=root)
>>> d = Node("d", age=40, parent=b)
>>> e = Node("e", age=35, parent=b)
>>> root.show()
a
├── b
│ ├── d
│ └── e
└── c

>>> tree_to_newick(root)
'((d,e)b,c)a'

Args:
tree (Node): tree to be exported

Returns:
(str)
"""
if not tree:
return ""
if tree.is_leaf:
return tree.node_name
children_newick = ",".join(tree_to_newick(child) for child in tree.children)
return f"({children_newick}){tree.node_name}"
77 changes: 58 additions & 19 deletions bigtree/tree/helper.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections import deque
from typing import Any, Deque, Dict, List, Type, TypeVar, Union
from typing import Any, Deque, Dict, List, Set, Type, TypeVar, Union

from bigtree.node.basenode import BaseNode
from bigtree.node.binarynode import BinaryNode
Expand Down Expand Up @@ -56,18 +56,21 @@ def recursive_add_child(

def prune_tree(
tree: Union[BinaryNodeT, NodeT],
prune_path: str = "",
prune_path: Union[List[str], str] = "",
exact: bool = False,
sep: str = "/",
max_depth: int = 0,
) -> Union[BinaryNodeT, NodeT]:
"""Prune tree by path or depth, returns the root of a *copy* of the original tree.

For pruning by `prune_path`,
All siblings along the prune path will be removed.
Prune path name should be unique, can be full path, partial path (trailing part of path), or node name.
- All siblings along the prune path will be removed.
- If ``exact=True``, all descendants of prune path will be removed.
- Prune path can be string (only one path) or a list of strings (multiple paths).
- Prune path name should be unique, can be full path, partial path (trailing part of path), or node name.

For pruning by `max_depth`,
All nodes that are beyond `max_depth` will be removed.
- All nodes that are beyond `max_depth` will be removed.

Path should contain ``Node`` name, separated by `sep`.
- For example: Path string "a/b" refers to Node("b") with parent Node("a").
Expand All @@ -85,13 +88,33 @@ def prune_tree(
│ └── d
└── e

Prune (default is keep descendants)

>>> root_pruned = prune_tree(root, "a/b")
>>> root_pruned.show()
a
└── b
├── c
└── d

Prune exact path

>>> root_pruned = prune_tree(root, "a/b", exact=True)
>>> root_pruned.show()
a
└── b

Prune multiple paths

>>> root_pruned = prune_tree(root, ["a/b/d", "a/e"])
>>> root_pruned.show()
a
├── b
│ └── d
└── e

Prune by depth

>>> root_pruned = prune_tree(root, max_depth=2)
>>> root_pruned.show()
a
Expand All @@ -100,31 +123,47 @@ def prune_tree(

Args:
tree (Union[BinaryNode, Node]): existing tree
prune_path (str): prune path, all siblings along the prune path will be removed
prune_path (List[str] | str): prune path(s), all siblings along the prune path(s) will be removed
exact (bool): prune path(s) to be exactly the path, defaults to False (descendants of the path are retained)
sep (str): path separator of `prune_path`
max_depth (int): maximum depth of pruned tree, based on `depth` attribute, defaults to None

Returns:
(Union[BinaryNode, Node])
"""
if not prune_path and not max_depth:
if isinstance(prune_path, str):
prune_path = [prune_path] if prune_path else []

if not len(prune_path) and not max_depth:
raise ValueError("Please specify either `prune_path` or `max_depth` or both.")

tree_copy = tree.copy()

# Prune by path (prune bottom-up)
if prune_path:
prune_path = prune_path.replace(sep, tree.sep)
child = find_path(tree_copy, prune_path)
if not child:
raise NotFoundError(
f"Cannot find any node matching path_name ending with {prune_path}"
)
while child.parent:
for other_children in child.parent.children:
if other_children != child:
other_children.parent = None
child = child.parent
if len(prune_path):
ancestors_to_prune: Set[Union[BinaryNodeT, NodeT]] = set()
nodes_to_prune: Set[Union[BinaryNodeT, NodeT]] = set()
for path in prune_path:
path = path.replace(sep, tree.sep)
child = find_path(tree_copy, path)
if not child:
raise NotFoundError(
f"Cannot find any node matching path_name ending with {path}"
)
nodes_to_prune.add(child)
ancestors_to_prune.update(list(child.ancestors))

if exact:
ancestors_to_prune.update(nodes_to_prune)

for node in ancestors_to_prune:
for child in node.children:
if (
child
and child not in ancestors_to_prune
and child not in nodes_to_prune
):
child.parent = None

# Prune by depth (prune top-down)
if max_depth:
Expand Down
8 changes: 8 additions & 0 deletions docs/source/bigtree/tree/export.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ Export Tree to list, dictionary, and pandas DataFrame.
- Method
* - Command Line / Others
- `print_tree`, `yield_tree`
* - String
- `tree_to_newick`
* - Dictionary
- `tree_to_dict`, `tree_to_nested_dict`
* - DataFrame
Expand Down Expand Up @@ -46,6 +48,12 @@ While exporting to another data type, methods can take in arguments to determine
- No, but can specify subtree
- No
- Tree style
* - `tree_to_newick`
- No
- No
- No
- No
- N/A
* - `tree_to_dict`
- Yes with `attr_dict` or `all_attrs`
- Yes with `max_depth`
Expand Down
9 changes: 9 additions & 0 deletions tests/binarytree/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
tree_to_dict,
tree_to_dot,
tree_to_nested_dict,
tree_to_newick,
)
from tests.conftest import assert_print_statement
from tests.test_constants import Constants
Expand Down Expand Up @@ -89,3 +90,11 @@ def test_tree_to_dot(binarytree_node):
assert (
expected_str in actual
), f"Expected {expected_str} not in actual string"


class TestTreeToNewick:
@staticmethod
def test_tree_to_newick(binarytree_node):
newick_str = tree_to_newick(binarytree_node)
expected_str = """(((,8)4,5)2,(6,7)3)1"""
assert newick_str == expected_str
32 changes: 32 additions & 0 deletions tests/binarytree/test_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,20 @@ def test_prune_tree(binarytree_node):
None,
), "Check children (depth 5)"

@staticmethod
def test_prune_tree_exact(binarytree_node):
# Pruned tree is 1/2/4
tree_prune = prune_tree(binarytree_node, "1/2/4", exact=True)

assert tree_prune.left.val == 2, "Check left child of tree_prune (depth 2)"
assert not tree_prune.right, "Check right child of tree_prune (depth 2)"
assert tree_prune.left.left.val == 4, "Check left child (depth 3)"
assert not tree_prune.left.right, "Check right child (depth 3)"
assert tree_prune.left.left.children == (
None,
None,
), "Check children (depth 4)"

@staticmethod
def test_prune_tree_second_child(binarytree_node):
# Pruned tree is 1/3/6, 1/3/7
Expand All @@ -48,6 +62,24 @@ def test_prune_tree_second_child(binarytree_node):
None,
), "Check children (depth 4)"

@staticmethod
def test_prune_tree_list(binarytree_node):
# Pruned tree is 1/3/6, 1/3/7
tree_prune = prune_tree(binarytree_node, ["1/3/6", "1/3/7"])

assert not tree_prune.left, "Check left child of tree_prune (depth 2)"
assert tree_prune.right.val == 3, "Check right child of tree_prune (depth 2)"
assert tree_prune.right.left.val == 6, "Check left child (depth 3)"
assert tree_prune.right.right.val == 7, "Check right child (depth 3)"
assert tree_prune.right.left.children == (
None,
None,
), "Check children (depth 4)"
assert tree_prune.right.right.children == (
None,
None,
), "Check children (depth 4)"


class TestTreeDiff:
@staticmethod
Expand Down
9 changes: 9 additions & 0 deletions tests/tree/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
tree_to_dot,
tree_to_mermaid,
tree_to_nested_dict,
tree_to_newick,
tree_to_pillow,
)
from tests.conftest import assert_print_statement
Expand Down Expand Up @@ -1376,3 +1377,11 @@ def get_node_attr(node):
expected_str = """```mermaid\n%%{ init: { \'flowchart\': { \'curve\': \'basis\' } } }%%\nflowchart TB\n0("a") --> 0-0("b")\n0-0 --> 0-0-0("d")\n0-0-0:::class0-0-0-0 --> 0-0-0-0("g")\n0-0 --> 0-0-1("e")\n0-0-1:::class0-0-1-0 --> 0-0-1-0("h")\n0("a") --> 0-1("c")\n0-1 --> 0-1-0("f")\nclassDef default stroke-width:1\nclassDef class0-0-0-0 fill:red,stroke:black,stroke-width:2\nclassDef class0-0-1-0 fill:red,stroke:black,stroke-width:2\n```"""

assert mermaid_md == expected_str


class TestTreeToNewick:
@staticmethod
def test_tree_to_newick(tree_node):
newick_str = tree_to_newick(tree_node)
expected_str = """((d,(g,h)e)b,(f)c)a"""
assert newick_str == expected_str
53 changes: 53 additions & 0 deletions tests/tree/test_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,29 @@ def test_prune_tree(tree_node):
assert len(tree_prune.children[0].children[0].children) == 0
assert len(tree_prune.children[0].children[1].children) == 2

@staticmethod
def test_prune_tree_exact(tree_node):
# Pruned tree is a/b/e/g
tree_prune = prune_tree(tree_node, "a/b/e/g", exact=True)

assert_tree_structure_basenode_root(tree_node)
assert_tree_structure_basenode_root_attr(tree_node)
assert len(list(tree_prune.children)) == 1
assert len(tree_prune.children[0].children) == 1
assert len(tree_prune.children[0].children[0].children) == 1

@staticmethod
def test_prune_tree_list(tree_node):
# Pruned tree is a/b/d, a/b/e/g, a/b/e/h
tree_prune = prune_tree(tree_node, ["a/b/d", "a/b/e/g", "a/b/e/h"])

assert_tree_structure_basenode_root(tree_node)
assert_tree_structure_basenode_root_attr(tree_node)
assert len(list(tree_prune.children)) == 1
assert len(tree_prune.children[0].children) == 2
assert len(tree_prune.children[0].children[0].children) == 0
assert len(tree_prune.children[0].children[1].children) == 2

@staticmethod
def test_prune_tree_path_and_depth(tree_node):
# Pruned tree is a/b/d, a/b/e (a/b/e/g, a/b/e/h pruned away)
Expand All @@ -76,6 +99,36 @@ def test_prune_tree_path_and_depth(tree_node):
len(tree_prune.children[0].children[1].children) == 0
), "Depth at 4 is not pruned away"

@staticmethod
def test_prune_tree_path_and_depth_and_list(tree_node):
# Pruned tree is a/b/d, a/b/e (a/b/e/g, a/b/e/h pruned away)
tree_prune = prune_tree(tree_node, ["a/b/d", "a/b/e/g", "a/b/e/h"], max_depth=3)

assert_tree_structure_basenode_root(tree_node)
assert_tree_structure_basenode_root_attr(tree_node)
assert len(list(tree_prune.children)) == 1
assert len(tree_prune.children[0].children) == 2
assert len(tree_prune.children[0].children[0].children) == 0
assert (
len(tree_prune.children[0].children[1].children) == 0
), "Depth at 4 is not pruned away"

@staticmethod
def test_prune_tree_path_and_depth_and_list_and_exact(tree_node):
# Pruned tree is a/b/d, a/b/e (a/b/e/g, a/b/e/h pruned away)
tree_prune = prune_tree(
tree_node, ["a/b/d", "a/b/e/g", "a/b/e/h"], exact=True, max_depth=3
)

assert_tree_structure_basenode_root(tree_node)
assert_tree_structure_basenode_root_attr(tree_node)
assert len(list(tree_prune.children)) == 1
assert len(tree_prune.children[0].children) == 2
assert len(tree_prune.children[0].children[0].children) == 0
assert (
len(tree_prune.children[0].children[1].children) == 0
), "Depth at 4 is not pruned away"

@staticmethod
def test_prune_tree_only_depth(tree_node):
# Pruned tree is a/b, a/c (a/b/e/g, a/b/e/h, a/c/f pruned away)
Expand Down