Skip to content

MAINT/API: Fix Table references in latest version of cogent3, simplify API of model for build and fit tree apps #201

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "piqtree"
dependencies = ["cogent3>=2025.3.22a2", "pyyaml", "requests"]
dependencies = ["cogent3>=2025.5.8a2", "pyyaml", "requests"]
requires-python = ">=3.11, <3.14"

authors = [{name="Gavin Huttley"}, {name="Robert McArthur"}, {name="Bui Quang Minh "}, {name="Richard Morris"}, {name="Thomas Wong"}]
Expand Down
24 changes: 4 additions & 20 deletions src/piqtree/_app/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,21 +23,13 @@ class piqtree_phylo:
@extend_docstring_from(build_tree)
def __init__(
self,
submod_type: str,
freq_type: str | None = None,
rate_model: str | None = None,
model: Model | str,
*,
invariant_sites: bool = False,
rand_seed: int | None = None,
bootstrap_reps: int | None = None,
num_threads: int | None = None,
) -> None:
self._model = Model(
submod_type=submod_type,
invariant_sites=invariant_sites,
rate_model=rate_model,
freq_type=freq_type,
)
self._model = model
self._rand_seed = rand_seed
self._bootstrap_reps = bootstrap_reps
self._num_threads = num_threads
Expand All @@ -61,21 +53,13 @@ class piqtree_fit:
def __init__(
self,
tree: cogent3.PhyloNode,
submod_type: str,
freq_type: str | None = None,
rate_model: str | None = None,
model: Model | str,
*,
rand_seed: int | None = None,
num_threads: int | None = None,
invariant_sites: bool = False,
) -> None:
self._tree = tree
self._model = Model(
submod_type=submod_type,
invariant_sites=invariant_sites,
rate_model=rate_model,
freq_type=freq_type,
)
self._model = model
self._rand_seed = rand_seed
self._num_threads = num_threads

Expand Down
16 changes: 11 additions & 5 deletions src/piqtree/iqtree/_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from piqtree.exceptions import ParseIqTreeError
from piqtree.iqtree._decorator import iqtree_func
from piqtree.model import DnaModel, Model
from piqtree.model import DnaModel, Model, make_model

iq_build_tree = iqtree_func(iq_build_tree, hide_files=True)
iq_fit_tree = iqtree_func(iq_fit_tree, hide_files=True)
Expand Down Expand Up @@ -196,7 +196,7 @@ def _process_tree_yaml(

def build_tree(
aln: c3_types.AlignedSeqsType,
model: Model,
model: Model | str,
rand_seed: int | None = None,
bootstrap_replicates: int | None = None,
num_threads: int | None = None,
Expand All @@ -209,7 +209,7 @@ def build_tree(
----------
aln : c3_types.AlignedSeqsType
The sequence alignment.
model : Model
model : Model | str
The substitution model with base frequencies and rate heterogeneity.
rand_seed : int | None, optional
The random seed - 0 or None means no seed, by default None.
Expand All @@ -227,6 +227,9 @@ def build_tree(
The IQ-TREE maximum likelihood tree from the given alignment.

"""
if isinstance(model, str):
model = make_model(model)

if rand_seed is None:
rand_seed = 0 # The default rand_seed in IQ-TREE

Expand Down Expand Up @@ -261,7 +264,7 @@ def build_tree(
def fit_tree(
aln: c3_types.AlignedSeqsType,
tree: cogent3.PhyloNode,
model: Model,
model: Model | str,
rand_seed: int | None = None,
num_threads: int | None = None,
) -> cogent3.PhyloNode:
Expand All @@ -276,7 +279,7 @@ def fit_tree(
The sequence alignment.
tree : cogent3.PhyloNode
The topology to fit branch lengths to.
model : Model
model : Model | str
The substitution model with base frequencies and rate heterogeneity.
rand_seed : int | None, optional
The random seed - 0 or None means no seed, by default None.
Expand All @@ -290,6 +293,9 @@ def fit_tree(
A phylogenetic tree with same given topology fitted with branch lengths.

"""
if isinstance(model, str):
model = make_model(model)

if rand_seed is None:
rand_seed = 0 # The default rand_seed in IQ-TREE

Expand Down
10 changes: 5 additions & 5 deletions src/piqtree/model/_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import functools
from typing import Literal

from cogent3 import _Table, make_table
from cogent3.core.table import Table, make_table

from piqtree.model._freq_type import FreqType
from piqtree.model._rate_type import ALL_BASE_RATE_TYPES, get_description
Expand Down Expand Up @@ -40,7 +40,7 @@ def available_models(
model_type: Literal["dna", "protein"] | None = None,
*,
show_all: bool = True,
) -> _Table:
) -> Table:
"""Return a table showing available substitution models.

Parameters
Expand All @@ -52,7 +52,7 @@ def available_models(

Returns
-------
_Table
Table
Table with all available models.

"""
Expand All @@ -78,7 +78,7 @@ def available_models(
return table


def available_freq_type() -> _Table:
def available_freq_type() -> Table:
"""Return a table showing available freq type options."""
data: dict[str, list[str]] = {"Freq Type": [], "Description": []}

Expand All @@ -89,7 +89,7 @@ def available_freq_type() -> _Table:
return make_table(data=data, title="Available frequency types")


def available_rate_type() -> _Table:
def available_rate_type() -> Table:
"""Return a table showing available rate type options."""
data: dict[str, list[str]] = {"Rate Type": [], "Description": []}

Expand Down
8 changes: 4 additions & 4 deletions tests/test_app/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,18 @@
from cogent3.core.new_alignment import Alignment

import piqtree
from piqtree import jc_distances
from piqtree import jc_distances, make_model


def test_piqtree_phylo(four_otu: Alignment) -> None:
expected = make_tree("(Human,Chimpanzee,(Rhesus,Mouse));")
app = get_app("piqtree_phylo", submod_type="JC")
app = get_app("piqtree_phylo", model="JC")
got = app(four_otu)
assert expected.same_topology(got)


def test_piqtree_phylo_support(four_otu: Alignment) -> None:
app = get_app("piqtree_phylo", submod_type="JC", bootstrap_reps=1000)
app = get_app("piqtree_phylo", model=make_model("JC"), bootstrap_reps=1000)
got = app(four_otu)
supports = [
node.params.get("support", None)
Expand All @@ -28,7 +28,7 @@ def test_piqtree_fit(three_otu: Alignment) -> None:
tree = make_tree(tip_names=three_otu.names)
app = get_app("model", "JC69", tree=tree)
expected = app(three_otu)
piphylo = get_app("piqtree_fit", tree=tree, submod_type="JC")
piphylo = get_app("piqtree_fit", tree=tree, model="JC")
got = piphylo(three_otu)
assert got.params["lnL"] == pytest.approx(expected.lnL)

Expand Down
12 changes: 11 additions & 1 deletion tests/test_iqtree/test_build_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def check_build_tree(
rate_model: RateModel | None = None,
*,
invariant_sites: bool = False,
coerce_str: bool = False,
) -> None:
expected = make_tree("(Human,Chimpanzee,(Rhesus,Mouse));")

Expand All @@ -32,7 +33,11 @@ def check_build_tree(
rate_model=rate_model,
)

got1 = piqtree.build_tree(four_otu, model, rand_seed=1)
got1 = piqtree.build_tree(
four_otu,
str(model) if coerce_str else model,
rand_seed=1,
)
got1 = got1.unrooted()
# Check topology
assert expected.same_topology(got1.unrooted())
Expand Down Expand Up @@ -61,6 +66,11 @@ def test_lie_build_tree(four_otu: Alignment, dna_model: DnaModel) -> None:
check_build_tree(four_otu, dna_model)


@pytest.mark.parametrize("dna_model", list(DnaModel)[-3:])
def test_str_build_tree(four_otu: Alignment, dna_model: DnaModel) -> None:
check_build_tree(four_otu, dna_model, coerce_str=True)


@pytest.mark.parametrize("dna_model", list(DnaModel)[:5])
@pytest.mark.parametrize("invariant_sites", [False, True])
@pytest.mark.parametrize(
Expand Down
42 changes: 40 additions & 2 deletions tests/test_iqtree/test_fit_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,52 @@ def test_fit_tree(three_otu: Alignment, iq_model: DnaModel, c3_model: str) -> No
app = get_app("model", c3_model, tree=tree_topology)
expected = app(three_otu)

got1 = piqtree.fit_tree(three_otu, tree_topology, Model(iq_model), rand_seed=1)
model = Model(iq_model)

got1 = piqtree.fit_tree(three_otu, tree_topology, model, rand_seed=1)
check_likelihood(got1, expected)
check_motif_probs(got1, expected.tree)
check_rate_parameters(got1, expected.tree)
check_branch_lengths(got1, expected.tree)

# Should be within an approximation for any seed
got2 = piqtree.fit_tree(three_otu, tree_topology, model, rand_seed=None)
check_likelihood(got2, expected)
check_motif_probs(got2, expected.tree)
check_rate_parameters(got2, expected.tree)
check_branch_lengths(got2, expected.tree)


@pytest.mark.parametrize(
("iq_model", "c3_model"),
[
(DnaModel.JC, "JC69"),
(DnaModel.K80, "K80"),
(DnaModel.GTR, "GTR"),
(DnaModel.TN, "TN93"),
(DnaModel.HKY, "HKY85"),
(DnaModel.F81, "F81"),
],
)
def test_fit_tree_str_model(
three_otu: Alignment,
iq_model: DnaModel,
c3_model: str,
) -> None:
tree_topology = make_tree(tip_names=three_otu.names)
app = get_app("model", c3_model, tree=tree_topology)
expected = app(three_otu)

model = str(Model(iq_model))

got1 = piqtree.fit_tree(three_otu, tree_topology, model, rand_seed=1)
check_likelihood(got1, expected)
check_motif_probs(got1, expected.tree)
check_rate_parameters(got1, expected.tree)
check_branch_lengths(got1, expected.tree)

# Should be within an approximation for any seed
got2 = piqtree.fit_tree(three_otu, tree_topology, Model(iq_model), rand_seed=None)
got2 = piqtree.fit_tree(three_otu, tree_topology, model, rand_seed=None)
check_likelihood(got2, expected)
check_motif_probs(got2, expected.tree)
check_rate_parameters(got2, expected.tree)
Expand Down