33from typing import Any , Dict , List , Tuple , Type
44
55from bigtree .node .dagnode import DAGNode
6+ from bigtree .utils .assertions import (
7+ assert_dataframe_no_duplicate_attribute ,
8+ assert_dataframe_not_empty ,
9+ assert_dictionary_not_empty ,
10+ assert_length_not_empty ,
11+ filter_attributes ,
12+ isnull ,
13+ )
614from bigtree .utils .exceptions import optional_dependencies_pandas
715
816try :
@@ -35,15 +43,15 @@ def list_to_dag(
3543 Returns:
3644 (DAGNode)
3745 """
38- if not len (relations ):
39- raise ValueError ("Input list does not contain any data, check `relations`" )
46+ assert_length_not_empty (relations , "Input list" , "relations" )
4047
4148 relation_data = pd .DataFrame (relations , columns = ["parent" , "child" ])
4249 return dataframe_to_dag (
4350 relation_data , child_col = "child" , parent_col = "parent" , node_type = node_type
4451 )
4552
4653
54+ @optional_dependencies_pandas
4755def dict_to_dag (
4856 relation_attrs : Dict [str , Any ],
4957 parent_key : str = "parents" ,
@@ -75,8 +83,7 @@ def dict_to_dag(
7583 Returns:
7684 (DAGNode)
7785 """
78- if not len (relation_attrs ):
79- raise ValueError ("Dictionary does not contain any data, check `relation_attrs`" )
86+ assert_dictionary_not_empty (relation_attrs , "relation_attrs" )
8087
8188 # Convert dictionary to dataframe
8289 data = pd .DataFrame (relation_attrs ).T .rename_axis ("_tmp_child" ).reset_index ()
@@ -110,6 +117,8 @@ def dataframe_to_dag(
110117 - If columns are not specified, `child_col` takes first column, `parent_col` takes second column, and all other
111118 columns are `attribute_cols`.
112119
120+ Only attributes in `attribute_cols` with non-null values will be added to the DAG.
121+
113122 Examples:
114123 >>> import pandas as pd
115124 >>> from bigtree import dataframe_to_dag, dag_iterator
@@ -141,12 +150,7 @@ def dataframe_to_dag(
141150 Returns:
142151 (DAGNode)
143152 """
144- data = data .copy ()
145-
146- if not len (data .columns ):
147- raise ValueError ("Data does not contain any columns, check `data`" )
148- if not len (data ):
149- raise ValueError ("Data does not contain any rows, check `data`" )
153+ assert_dataframe_not_empty (data )
150154
151155 if not child_col :
152156 child_col = data .columns [0 ]
@@ -160,27 +164,12 @@ def dataframe_to_dag(
160164 attribute_cols = list (data .columns )
161165 attribute_cols .remove (child_col )
162166 attribute_cols .remove (parent_col )
163- elif any ([col not in data .columns for col in attribute_cols ]):
164- raise ValueError (
165- f"One or more attribute column(s) not in data, check `attribute_cols`: { attribute_cols } "
166- )
167167
168- data_check = data .copy ()[[child_col , parent_col ] + attribute_cols ].drop_duplicates (
169- subset = [child_col ] + attribute_cols
170- )
171- _duplicate_check = (
172- data_check [child_col ]
173- .value_counts ()
174- .to_frame ("counts" )
175- .rename_axis (child_col )
176- .reset_index ()
168+ data = data [[child_col , parent_col ] + attribute_cols ].copy ()
169+
170+ assert_dataframe_no_duplicate_attribute (
171+ data , "child name" , child_col , attribute_cols
177172 )
178- _duplicate_check = _duplicate_check [_duplicate_check ["counts" ] > 1 ]
179- if len (_duplicate_check ):
180- raise ValueError (
181- f"There exists duplicate child name with different attributes\n "
182- f"Check { _duplicate_check } "
183- )
184173 if sum (data [child_col ].isnull ()):
185174 raise ValueError (f"Child name cannot be empty, check column: { child_col } " )
186175
@@ -190,15 +179,14 @@ def dataframe_to_dag(
190179 for row in data .reset_index (drop = True ).to_dict (orient = "index" ).values ():
191180 child_name = row [child_col ]
192181 parent_name = row [parent_col ]
193- node_attrs = row .copy ()
194- del node_attrs [child_col ]
195- del node_attrs [parent_col ]
196- node_attrs = {k : v for k , v in node_attrs .items () if not pd .isnull (v )}
197- child_node = node_dict .get (child_name , node_type (child_name ))
182+ node_attrs = filter_attributes (
183+ row , omit_keys = ["name" , child_col , parent_col ], omit_null_values = True
184+ )
185+ child_node = node_dict .get (child_name , node_type (child_name , ** node_attrs ))
198186 child_node .set_attrs (node_attrs )
199187 node_dict [child_name ] = child_node
200188
201- if not pd . isnull (parent_name ):
189+ if not isnull (parent_name ):
202190 parent_node = node_dict .get (parent_name , node_type (parent_name ))
203191 node_dict [parent_name ] = parent_node
204192 child_node .parents = [parent_node ]
0 commit comments