Skip to content

Commit abd5a3a

Browse files
authored
Merge pull request #248 from kayjan/fix-type-hints
Fix type hints
2 parents f826e13 + 7fe28b7 commit abd5a3a

File tree

6 files changed

+87
-75
lines changed

6 files changed

+87
-75
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
66

77
## [Unreleased]
88
### Changed:
9+
- Binary Tree Constructor: Type hints to return more generic TypeVar for use with subclasses.
10+
- DAG Constructor: Type hints to return more generic TypeVar for use with subclasses.
11+
- [#247] Tree Constructor: Type hints to return more generic TypeVar for use with subclasses.
12+
- Tree Modifier: Type hints to return more generic TypeVar for use with subclasses.
913
- Misc: Documentation to include more contribution information and guidelines.
1014

1115
## [0.18.2] - 2024-06-01

bigtree/binarytree/construct.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
1-
from typing import List, Type
1+
from typing import List, Type, TypeVar
22

33
from bigtree.node.binarynode import BinaryNode
4+
from bigtree.utils.assertions import assert_length_not_empty
45

56
__all__ = ["list_to_binarytree"]
67

7-
from bigtree.utils.assertions import assert_length_not_empty
8+
T = TypeVar("T", bound=BinaryNode)
89

910

1011
def list_to_binarytree(
11-
heapq_list: List[int], node_type: Type[BinaryNode] = BinaryNode
12-
) -> BinaryNode:
12+
heapq_list: List[int],
13+
node_type: Type[T] = BinaryNode, # type: ignore[assignment]
14+
) -> T:
1315
"""Construct tree from a list of numbers (int or float) in heapq format.
1416
1517
Examples:
@@ -46,6 +48,6 @@ def list_to_binarytree(
4648
for idx, num in enumerate(heapq_list):
4749
if idx:
4850
parent_idx = int((idx + 1) / 2) - 1
49-
node = node_type(num, parent=node_list[parent_idx]) # type: ignore
51+
node = node_type(num, parent=node_list[parent_idx])
5052
node_list.append(node)
5153
return root_node

bigtree/dag/construct.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import Any, Dict, List, Tuple, Type
3+
from typing import Any, Dict, List, Tuple, Type, TypeVar
44

55
from bigtree.node.dagnode import DAGNode
66
from bigtree.utils.assertions import (
@@ -20,11 +20,13 @@
2020

2121
__all__ = ["list_to_dag", "dict_to_dag", "dataframe_to_dag"]
2222

23+
T = TypeVar("T", bound=DAGNode)
24+
2325

2426
def list_to_dag(
2527
relations: List[Tuple[str, str]],
26-
node_type: Type[DAGNode] = DAGNode,
27-
) -> DAGNode:
28+
node_type: Type[T] = DAGNode, # type: ignore[assignment]
29+
) -> T:
2830
"""Construct DAG from list of tuples containing parent-child names.
2931
Note that node names must be unique.
3032
@@ -44,8 +46,8 @@ def list_to_dag(
4446
"""
4547
assert_length_not_empty(relations, "Input list", "relations")
4648

47-
node_dict: Dict[str, DAGNode] = dict()
48-
parent_node = DAGNode()
49+
node_dict: Dict[str, T] = dict()
50+
parent_node: T = DAGNode() # type: ignore[assignment]
4951

5052
for parent_name, child_name in relations:
5153
if parent_name not in node_dict:
@@ -67,8 +69,8 @@ def list_to_dag(
6769
def dict_to_dag(
6870
relation_attrs: Dict[str, Any],
6971
parent_key: str = "parents",
70-
node_type: Type[DAGNode] = DAGNode,
71-
) -> DAGNode:
72+
node_type: Type[T] = DAGNode, # type: ignore[assignment]
73+
) -> T:
7274
"""Construct DAG from nested dictionary, ``key``: child name, ``value``: dictionary of parent names, attribute
7375
name, and attribute value.
7476
Note that node names must be unique.
@@ -97,8 +99,8 @@ def dict_to_dag(
9799
"""
98100
assert_length_not_empty(relation_attrs, "Dictionary", "relation_attrs")
99101

100-
node_dict: Dict[str, DAGNode] = dict()
101-
parent_node: DAGNode | None = None
102+
node_dict: Dict[str, T] = dict()
103+
parent_node: T | None = None
102104

103105
for child_name, node_attrs in relation_attrs.items():
104106
node_attrs = node_attrs.copy()
@@ -133,8 +135,8 @@ def dataframe_to_dag(
133135
child_col: str = "",
134136
parent_col: str = "",
135137
attribute_cols: List[str] = [],
136-
node_type: Type[DAGNode] = DAGNode,
137-
) -> DAGNode:
138+
node_type: Type[T] = DAGNode, # type: ignore[assignment]
139+
) -> T:
138140
"""Construct DAG from pandas DataFrame.
139141
Note that node names must be unique.
140142
@@ -200,8 +202,8 @@ def dataframe_to_dag(
200202
if sum(data[child_col].isnull()):
201203
raise ValueError(f"Child name cannot be empty, check column: {child_col}")
202204

203-
node_dict: Dict[str, DAGNode] = dict()
204-
parent_node = DAGNode()
205+
node_dict: Dict[str, T] = dict()
206+
parent_node: T = DAGNode() # type: ignore[assignment]
205207

206208
for row in data.reset_index(drop=True).to_dict(orient="index").values():
207209
child_name = row[child_col]

bigtree/tree/construct.py

Lines changed: 45 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import re
44
from collections import OrderedDict, defaultdict
5-
from typing import Any, Dict, List, Optional, Tuple, Type
5+
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar
66

77
from bigtree.node.node import Node
88
from bigtree.tree.search import find_child_by_name, find_name
@@ -51,14 +51,16 @@
5151
"newick_to_tree",
5252
]
5353

54+
T = TypeVar("T", bound=Node)
55+
5456

5557
def add_path_to_tree(
56-
tree: Node,
58+
tree: T,
5759
path: str,
5860
sep: str = "/",
5961
duplicate_name_allowed: bool = True,
6062
node_attrs: Dict[str, Any] = {},
61-
) -> Node:
63+
) -> T:
6264
"""Add nodes and attributes to existing tree *in-place*, return node of path added.
6365
Adds to existing tree from list of path strings.
6466
@@ -136,11 +138,11 @@ def add_path_to_tree(
136138

137139

138140
def add_dict_to_tree_by_path(
139-
tree: Node,
141+
tree: T,
140142
path_attrs: Dict[str, Dict[str, Any]],
141143
sep: str = "/",
142144
duplicate_name_allowed: bool = True,
143-
) -> Node:
145+
) -> T:
144146
"""Add nodes and attributes to tree *in-place*, return root of tree.
145147
Adds to existing tree from nested dictionary, ``key``: path, ``value``: dict of attribute name and attribute value.
146148
@@ -208,7 +210,7 @@ def add_dict_to_tree_by_path(
208210
return root_node
209211

210212

211-
def add_dict_to_tree_by_name(tree: Node, name_attrs: Dict[str, Dict[str, Any]]) -> Node:
213+
def add_dict_to_tree_by_name(tree: T, name_attrs: Dict[str, Dict[str, Any]]) -> T:
212214
"""Add attributes to existing tree *in-place*.
213215
Adds to existing tree from nested dictionary, ``key``: name, ``value``: dict of attribute name and attribute value.
214216
@@ -254,13 +256,13 @@ def add_dict_to_tree_by_name(tree: Node, name_attrs: Dict[str, Dict[str, Any]])
254256

255257

256258
def add_dataframe_to_tree_by_path(
257-
tree: Node,
259+
tree: T,
258260
data: pd.DataFrame,
259261
path_col: str = "",
260262
attribute_cols: List[str] = [],
261263
sep: str = "/",
262264
duplicate_name_allowed: bool = True,
263-
) -> Node:
265+
) -> T:
264266
"""Add nodes and attributes to tree *in-place*, return root of tree.
265267
Adds to existing tree from pandas DataFrame.
266268
@@ -350,11 +352,11 @@ def add_dataframe_to_tree_by_path(
350352

351353

352354
def add_dataframe_to_tree_by_name(
353-
tree: Node,
355+
tree: T,
354356
data: pd.DataFrame,
355357
name_col: str = "",
356358
attribute_cols: List[str] = [],
357-
) -> Node:
359+
) -> T:
358360
"""Add attributes to existing tree *in-place*.
359361
Adds to existing tree from pandas DataFrame.
360362
@@ -418,13 +420,13 @@ def add_dataframe_to_tree_by_name(
418420

419421

420422
def add_polars_to_tree_by_path(
421-
tree: Node,
423+
tree: T,
422424
data: pl.DataFrame,
423425
path_col: str = "",
424426
attribute_cols: List[str] = [],
425427
sep: str = "/",
426428
duplicate_name_allowed: bool = True,
427-
) -> Node:
429+
) -> T:
428430
"""Add nodes and attributes to tree *in-place*, return root of tree.
429431
Adds to existing tree from polars DataFrame.
430432
@@ -516,11 +518,11 @@ def add_polars_to_tree_by_path(
516518

517519

518520
def add_polars_to_tree_by_name(
519-
tree: Node,
521+
tree: T,
520522
data: pl.DataFrame,
521523
name_col: str = "",
522524
attribute_cols: List[str] = [],
523-
) -> Node:
525+
) -> T:
524526
"""Add attributes to existing tree *in-place*.
525527
Adds to existing tree from polars DataFrame.
526528
@@ -586,8 +588,8 @@ def add_polars_to_tree_by_name(
586588
def str_to_tree(
587589
tree_string: str,
588590
tree_prefix_list: List[str] = [],
589-
node_type: Type[Node] = Node,
590-
) -> Node:
591+
node_type: Type[T] = Node, # type: ignore[assignment]
592+
) -> T:
591593
r"""Construct tree from tree string
592594
593595
Examples:
@@ -655,8 +657,8 @@ def list_to_tree(
655657
paths: List[str],
656658
sep: str = "/",
657659
duplicate_name_allowed: bool = True,
658-
node_type: Type[Node] = Node,
659-
) -> Node:
660+
node_type: Type[T] = Node, # type: ignore[assignment]
661+
) -> T:
660662
"""Construct tree from list of path strings.
661663
662664
Path should contain ``Node`` name, separated by `sep`.
@@ -715,8 +717,8 @@ def list_to_tree(
715717
def list_to_tree_by_relation(
716718
relations: List[Tuple[str, str]],
717719
allow_duplicates: bool = False,
718-
node_type: Type[Node] = Node,
719-
) -> Node:
720+
node_type: Type[T] = Node, # type: ignore[assignment]
721+
) -> T:
720722
"""Construct tree from list of tuple containing parent-child names.
721723
722724
Root node is inferred when parent is empty, or when name appears as parent but not as child.
@@ -764,8 +766,8 @@ def dict_to_tree(
764766
path_attrs: Dict[str, Any],
765767
sep: str = "/",
766768
duplicate_name_allowed: bool = True,
767-
node_type: Type[Node] = Node,
768-
) -> Node:
769+
node_type: Type[T] = Node, # type: ignore[assignment]
770+
) -> T:
769771
"""Construct tree from nested dictionary using path,
770772
``key``: path, ``value``: dict of attribute name and attribute value.
771773
@@ -854,8 +856,8 @@ def nested_dict_to_tree(
854856
node_attrs: Dict[str, Any],
855857
name_key: str = "name",
856858
child_key: str = "children",
857-
node_type: Type[Node] = Node,
858-
) -> Node:
859+
node_type: Type[T] = Node, # type: ignore[assignment]
860+
) -> T:
859861
"""Construct tree from nested recursive dictionary.
860862
861863
- ``key``: `name_key`, `child_key`, or any attributes key.
@@ -901,8 +903,8 @@ def nested_dict_to_tree(
901903
assert_length_not_empty(node_attrs, "Dictionary", "node_attrs")
902904

903905
def _recursive_add_child(
904-
child_dict: Dict[str, Any], parent_node: Optional[Node] = None
905-
) -> Node:
906+
child_dict: Dict[str, Any], parent_node: Optional[T] = None
907+
) -> T:
906908
"""Recursively add child to tree, given child attributes and parent node.
907909
908910
Args:
@@ -934,8 +936,8 @@ def dataframe_to_tree(
934936
attribute_cols: List[str] = [],
935937
sep: str = "/",
936938
duplicate_name_allowed: bool = True,
937-
node_type: Type[Node] = Node,
938-
) -> Node:
939+
node_type: Type[T] = Node, # type: ignore[assignment]
940+
) -> T:
939941
"""Construct tree from pandas DataFrame using path, return root of tree.
940942
941943
`path_col` and `attribute_cols` specify columns for node path and attributes to construct tree.
@@ -1040,8 +1042,8 @@ def dataframe_to_tree_by_relation(
10401042
parent_col: str = "",
10411043
attribute_cols: List[str] = [],
10421044
allow_duplicates: bool = False,
1043-
node_type: Type[Node] = Node,
1044-
) -> Node:
1045+
node_type: Type[T] = Node, # type: ignore[assignment]
1046+
) -> T:
10451047
"""Construct tree from pandas DataFrame using parent and child names, return root of tree.
10461048
10471049
Root node is inferred when parent name is empty, or when name appears in parent column but not in child column.
@@ -1138,7 +1140,7 @@ def _retrieve_attr(_row: Dict[str, Any]) -> Dict[str, Any]:
11381140
node_attrs["name"] = _row[child_col]
11391141
return node_attrs
11401142

1141-
def _recursive_add_child(parent_node: Node) -> None:
1143+
def _recursive_add_child(parent_node: T) -> None:
11421144
"""Recursive add child to tree, given current node.
11431145
11441146
Args:
@@ -1168,8 +1170,8 @@ def polars_to_tree(
11681170
attribute_cols: List[str] = [],
11691171
sep: str = "/",
11701172
duplicate_name_allowed: bool = True,
1171-
node_type: Type[Node] = Node,
1172-
) -> Node:
1173+
node_type: Type[T] = Node, # type: ignore[assignment]
1174+
) -> T:
11731175
"""Construct tree from polars DataFrame using path, return root of tree.
11741176
11751177
`path_col` and `attribute_cols` specify columns for node path and attributes to construct tree.
@@ -1275,8 +1277,8 @@ def polars_to_tree_by_relation(
12751277
parent_col: str = "",
12761278
attribute_cols: List[str] = [],
12771279
allow_duplicates: bool = False,
1278-
node_type: Type[Node] = Node,
1279-
) -> Node:
1280+
node_type: Type[T] = Node, # type: ignore[assignment]
1281+
) -> T:
12801282
"""Construct tree from polars DataFrame using parent and child names, return root of tree.
12811283
12821284
Root node is inferred when parent name is empty, or when name appears in parent column but not in child column.
@@ -1373,7 +1375,7 @@ def _retrieve_attr(_row: Dict[str, Any]) -> Dict[str, Any]:
13731375
node_attrs["name"] = _row[child_col]
13741376
return node_attrs
13751377

1376-
def _recursive_add_child(parent_node: Node) -> None:
1378+
def _recursive_add_child(parent_node: T) -> None:
13771379
"""Recursive add child to tree, given current node.
13781380
13791381
Args:
@@ -1402,8 +1404,8 @@ def newick_to_tree(
14021404
tree_string: str,
14031405
length_attr: str = "length",
14041406
attr_prefix: str = "&&NHX:",
1405-
node_type: Type[Node] = Node,
1406-
) -> Node:
1407+
node_type: Type[T] = Node, # type: ignore[assignment]
1408+
) -> T:
14071409
"""Construct tree from Newick notation, return root of tree.
14081410
14091411
In the Newick Notation (or New Hampshire Notation)
@@ -1460,24 +1462,24 @@ def newick_to_tree(
14601462
assert_length_not_empty(tree_string, "Tree string", "tree_string")
14611463

14621464
# Store results (for tracking)
1463-
depth_nodes: Dict[int, List[Node]] = defaultdict(list)
1465+
depth_nodes: Dict[int, List[T]] = defaultdict(list)
14641466
unlabelled_node_counter: int = 0
14651467
current_depth: int = 1
14661468
tree_string_idx: int = 0
14671469

14681470
# Store states (for assertions and checks)
14691471
current_state: NewickState = NewickState.PARSE_STRING
1470-
current_node: Optional[Node] = None
1472+
current_node: Optional[T] = None
14711473
cumulative_string: str = ""
14721474
cumulative_string_value: str = ""
14731475

14741476
def _create_node(
1475-
_new_node: Optional[Node],
1477+
_new_node: Optional[T],
14761478
_cumulative_string: str,
14771479
_unlabelled_node_counter: int,
1478-
_depth_nodes: Dict[int, List[Node]],
1480+
_depth_nodes: Dict[int, List[T]],
14791481
_current_depth: int,
1480-
) -> Tuple[Node, int]:
1482+
) -> Tuple[T, int]:
14811483
"""Create node at checkpoint.
14821484
14831485
Args:

0 commit comments

Comments
 (0)