66from bigtree .utils .assertions import (
77 assert_dataframe_no_duplicate_attribute ,
88 assert_dataframe_not_empty ,
9- assert_dictionary_not_empty ,
9+ assert_key_not_in_dict_or_df ,
1010 assert_length_not_empty ,
1111 filter_attributes ,
1212 isnull ,
2121__all__ = ["list_to_dag" , "dict_to_dag" , "dataframe_to_dag" ]
2222
2323
24- @optional_dependencies_pandas
2524def list_to_dag (
2625 relations : List [Tuple [str , str ]],
2726 node_type : Type [DAGNode ] = DAGNode ,
@@ -45,13 +44,26 @@ def list_to_dag(
4544 """
4645 assert_length_not_empty (relations , "Input list" , "relations" )
4746
48- relation_data = pd .DataFrame (relations , columns = ["parent" , "child" ])
49- return dataframe_to_dag (
50- relation_data , child_col = "child" , parent_col = "parent" , node_type = node_type
51- )
47+ node_dict : Dict [str , DAGNode ] = dict ()
48+ parent_node = DAGNode ()
49+
50+ for parent_name , child_name in relations :
51+ if parent_name not in node_dict :
52+ parent_node = node_type (parent_name )
53+ node_dict [parent_name ] = parent_node
54+ else :
55+ parent_node = node_dict [parent_name ]
56+ if child_name not in node_dict :
57+ child_node = node_type (child_name )
58+ node_dict [child_name ] = child_node
59+ else :
60+ child_node = node_dict [child_name ]
61+
62+ child_node .parents = [parent_node ]
63+
64+ return parent_node
5265
5366
54- @optional_dependencies_pandas
5567def dict_to_dag (
5668 relation_attrs : Dict [str , Any ],
5769 parent_key : str = "parents" ,
@@ -83,22 +95,36 @@ def dict_to_dag(
8395 Returns:
8496 (DAGNode)
8597 """
86- assert_dictionary_not_empty (relation_attrs , "relation_attrs" )
98+ assert_length_not_empty (relation_attrs , "Dictionary" , "relation_attrs" )
99+
100+ node_dict : Dict [str , DAGNode ] = dict ()
101+ parent_node : DAGNode | None = None
102+
103+ for child_name , node_attrs in relation_attrs .items ():
104+ node_attrs = node_attrs .copy ()
105+ parent_names : List [str ] = []
106+ if parent_key in node_attrs :
107+ parent_names = node_attrs .pop (parent_key )
108+ assert_key_not_in_dict_or_df (node_attrs , ["parent" , "parents" , "children" ])
109+
110+ if child_name in node_dict :
111+ child_node = node_dict [child_name ]
112+ child_node .set_attrs (node_attrs )
113+ else :
114+ child_node = node_type (child_name , ** node_attrs )
115+ node_dict [child_name ] = child_node
116+
117+ for parent_name in parent_names :
118+ parent_node = node_dict .get (parent_name , node_type (parent_name ))
119+ node_dict [parent_name ] = parent_node
120+ child_node .parents = [parent_node ]
87121
88- # Convert dictionary to dataframe
89- data = pd .DataFrame (relation_attrs ).T .rename_axis ("_tmp_child" ).reset_index ()
90- if parent_key not in data :
122+ if parent_node is None :
91123 raise ValueError (
92124 f"Parent key { parent_key } not in dictionary, check `relation_attrs` and `parent_key`"
93125 )
94126
95- data = data .explode (parent_key )
96- return dataframe_to_dag (
97- data ,
98- child_col = "_tmp_child" ,
99- parent_col = parent_key ,
100- node_type = node_type ,
101- )
127+ return parent_node
102128
103129
104130@optional_dependencies_pandas
@@ -164,6 +190,7 @@ def dataframe_to_dag(
164190 attribute_cols = list (data .columns )
165191 attribute_cols .remove (child_col )
166192 attribute_cols .remove (parent_col )
193+ assert_key_not_in_dict_or_df (attribute_cols , ["parent" , "parents" , "children" ])
167194
168195 data = data [[child_col , parent_col ] + attribute_cols ].copy ()
169196
0 commit comments