Skip to content

Commit d815658

Browse files
YapengLangrmcar17
authored andcommitted
ENH: reform rate pars to be consistent with cogent3
#49
1 parent fdc836b commit d815658

File tree

1 file changed

+31
-7
lines changed

1 file changed

+31
-7
lines changed

src/piqtree2/iqtree/_tree.py

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from piqtree2.exceptions import ParseIqTreeError
1111
from piqtree2.iqtree._decorator import iqtree_func
12-
from piqtree2.model import Model
12+
from piqtree2.model import DnaModel, Model
1313

1414
iq_build_tree = iqtree_func(iq_build_tree, hide_files=True)
1515
iq_fit_tree = iqtree_func(iq_fit_tree, hide_files=True)
@@ -32,7 +32,19 @@ def _intrude_edge_params(tree: cogent3.PhyloNode, **kwargs: dict) -> None:
3232
node.params.update(kwargs)
3333

3434

35-
def _process_tree_yaml(tree_yaml: dict, names: Sequence[str]) -> cogent3.PhyloNode:
35+
def _reform_rate_params(rate_pars: dict, model: Model) -> dict:
36+
if model.substitution_model in {DnaModel.K80, DnaModel.HKY}:
37+
rate_pars = {"kappa": rate_pars["A/G"]}
38+
elif model.substitution_model is DnaModel.TN:
39+
rate_pars = {"kappa_r": rate_pars["A/G"], "kappa_y": rate_pars["C/T"]}
40+
elif model.substitution_model is DnaModel.GTR:
41+
del rate_pars["G/T"]
42+
return rate_pars
43+
44+
45+
def _process_tree_yaml(
46+
tree_yaml: dict, names: Sequence[str], model: Model,
47+
) -> cogent3.PhyloNode:
3648
newick = tree_yaml["PhyloTree"]["newick"]
3749

3850
tree = cogent3.make_tree(newick)
@@ -63,9 +75,21 @@ def _process_tree_yaml(tree_yaml: dict, names: Sequence[str]) -> cogent3.PhyloNo
6375
),
6476
),
6577
}
66-
_intrude_edge_params(
67-
tree, **rate_pars, **motif_pars,
68-
) # add global rate parameters to the edges
78+
79+
if model.substitution_model in {DnaModel.JC, DnaModel.F81}:
80+
_intrude_edge_params(
81+
tree,
82+
**motif_pars, # skip rate_pars since rate parameters are constant in JC and F81
83+
)
84+
else:
85+
rate_pars = _reform_rate_params(
86+
rate_pars, model,
87+
) # reform rate parameters in cogent3 way
88+
_intrude_edge_params(
89+
tree,
90+
**rate_pars,
91+
**motif_pars,
92+
) # add global rate parameters to the edges
6993

7094
return tree
7195

@@ -101,7 +125,7 @@ def build_tree(
101125
seqs = [str(seq) for seq in aln.iter_seqs(names)]
102126

103127
yaml_result = yaml.safe_load(iq_build_tree(names, seqs, str(model), rand_seed))
104-
return _process_tree_yaml(yaml_result, names)
128+
return _process_tree_yaml(yaml_result, names, model)
105129

106130

107131
def fit_tree(
@@ -142,4 +166,4 @@ def fit_tree(
142166
yaml_result = yaml.safe_load(
143167
iq_fit_tree(names, seqs, str(model), newick, rand_seed),
144168
)
145-
return _process_tree_yaml(yaml_result, names)
169+
return _process_tree_yaml(yaml_result, names, model)

0 commit comments

Comments
 (0)