9
9
10
10
from piqtree2 .exceptions import ParseIqTreeError
11
11
from piqtree2 .iqtree ._decorator import iqtree_func
12
- from piqtree2 .model import Model
12
+ from piqtree2 .model import DnaModel , Model
13
13
14
14
iq_build_tree = iqtree_func (iq_build_tree , hide_files = True )
15
15
iq_fit_tree = iqtree_func (iq_fit_tree , hide_files = True )
@@ -32,7 +32,19 @@ def _intrude_edge_params(tree: cogent3.PhyloNode, **kwargs: dict) -> None:
32
32
node .params .update (kwargs )
33
33
34
34
35
- def _process_tree_yaml (tree_yaml : dict , names : Sequence [str ]) -> cogent3 .PhyloNode :
35
+ def _reform_rate_params (rate_pars : dict , model : Model ) -> dict :
36
+ if model .substitution_model in {DnaModel .K80 , DnaModel .HKY }:
37
+ rate_pars = {"kappa" : rate_pars ["A/G" ]}
38
+ elif model .substitution_model is DnaModel .TN :
39
+ rate_pars = {"kappa_r" : rate_pars ["A/G" ], "kappa_y" : rate_pars ["C/T" ]}
40
+ elif model .substitution_model is DnaModel .GTR :
41
+ del rate_pars ["G/T" ]
42
+ return rate_pars
43
+
44
+
45
+ def _process_tree_yaml (
46
+ tree_yaml : dict , names : Sequence [str ], model : Model ,
47
+ ) -> cogent3 .PhyloNode :
36
48
newick = tree_yaml ["PhyloTree" ]["newick" ]
37
49
38
50
tree = cogent3 .make_tree (newick )
@@ -63,9 +75,21 @@ def _process_tree_yaml(tree_yaml: dict, names: Sequence[str]) -> cogent3.PhyloNo
63
75
),
64
76
),
65
77
}
66
- _intrude_edge_params (
67
- tree , ** rate_pars , ** motif_pars ,
68
- ) # add global rate parameters to the edges
78
+
79
+ if model .substitution_model in {DnaModel .JC , DnaModel .F81 }:
80
+ _intrude_edge_params (
81
+ tree ,
82
+ ** motif_pars , # skip rate_pars since rate parameters are constant in JC and F81
83
+ )
84
+ else :
85
+ rate_pars = _reform_rate_params (
86
+ rate_pars , model ,
87
+ ) # reform rate parameters in cogent3 way
88
+ _intrude_edge_params (
89
+ tree ,
90
+ ** rate_pars ,
91
+ ** motif_pars ,
92
+ ) # add global rate parameters to the edges
69
93
70
94
return tree
71
95
@@ -101,7 +125,7 @@ def build_tree(
101
125
seqs = [str (seq ) for seq in aln .iter_seqs (names )]
102
126
103
127
yaml_result = yaml .safe_load (iq_build_tree (names , seqs , str (model ), rand_seed ))
104
- return _process_tree_yaml (yaml_result , names )
128
+ return _process_tree_yaml (yaml_result , names , model )
105
129
106
130
107
131
def fit_tree (
@@ -142,4 +166,4 @@ def fit_tree(
142
166
yaml_result = yaml .safe_load (
143
167
iq_fit_tree (names , seqs , str (model ), newick , rand_seed ),
144
168
)
145
- return _process_tree_yaml (yaml_result , names )
169
+ return _process_tree_yaml (yaml_result , names , model )
0 commit comments