Skip to content

Commit 697d80a

Browse files
authored
Merge pull request #306 from nspope/faster-prior
Faster prior variance calculation
2 parents a47cec7 + 36d91aa commit 697d80a

File tree

8 files changed: 161 additions (+161) and 78 deletions (-78)

8 files changed

+161
-78
lines changed

tests/test_cache.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def test_cached_prior(self):
3737
priors_approxNone.add(10)
3838
assert np.allclose(priors_approx10[10], priors_approxNone[10], equal_nan=True)
3939
# Test when using a bigger n that we're using the precalculated version
40-
priors_approx10.add(100)
40+
priors_approx10.add(100, approximate=True)
4141
assert priors_approx10[100].shape[0] == 100 + 1
4242
priors_approxNone.add(100, approximate=False)
4343
assert priors_approxNone[100].shape[0] == 100 + 1

tests/test_cli.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,6 @@ def test_progress(self, tmp_path, capfd):
265265
desc = (
266266
"Find Node Spans",
267267
"TipCount",
268-
"Calculating Node Age Variances",
269268
"Find Mixture Priors",
270269
"Inside",
271270
"Outside",

tests/test_functions.py

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -63,28 +63,14 @@ class TestBasicFunctions:
6363
Test for some of the basic functions used in tsdate
6464
"""
6565

66-
def test_alpha_prob(self):
67-
assert ConditionalCoalescentTimes.m_prob(2, 2, 3) == 1.0
68-
assert ConditionalCoalescentTimes.m_prob(2, 2, 4) == 0.25
69-
7066
def test_tau_expect(self):
7167
assert ConditionalCoalescentTimes.tau_expect(10, 10) == 1.8
7268
assert ConditionalCoalescentTimes.tau_expect(10, 100) == 0.09
7369
assert ConditionalCoalescentTimes.tau_expect(100, 100) == 1.98
7470
assert ConditionalCoalescentTimes.tau_expect(5, 10) == 0.4
7571

76-
def test_tau_squared_conditional(self):
77-
assert ConditionalCoalescentTimes.tau_squared_conditional(
78-
1, 10
79-
) == pytest.approx(4.3981418)
80-
assert ConditionalCoalescentTimes.tau_squared_conditional(
81-
100, 100
82-
) == pytest.approx(4.87890977e-18)
83-
84-
def test_tau_var(self):
85-
assert ConditionalCoalescentTimes.tau_var(2, 2) == 1
86-
assert ConditionalCoalescentTimes.tau_var(10, 20) == pytest.approx(0.0922995960)
87-
assert ConditionalCoalescentTimes.tau_var(50, 50) == pytest.approx(1.15946186)
72+
def test_tau_var_mrca(self):
73+
assert np.isclose(ConditionalCoalescentTimes.tau_var_mrca(50), 1.15946186)
8874

8975
def test_gamma_approx(self):
9076
assert gamma_approx(2, 1) == (4.0, 2.0)
@@ -1880,7 +1866,9 @@ def test_node_selection_param(self):
18801866
def test_sites_time_insideoutside(self):
18811867
ts = utility_functions.two_tree_mutation_ts()
18821868
dated = tsdate.date(ts, mutation_rate=None, population_size=1)
1883-
_, mn_post, _, _, eps, _, _ = get_dates(ts, mutation_rate=None, population_size=1)
1869+
_, mn_post, _, _, eps, _, _ = get_dates(
1870+
ts, mutation_rate=None, population_size=1
1871+
)
18841872
assert np.array_equal(
18851873
mn_post[ts.tables.mutations.node],
18861874
tsdate.sites_time_from_ts(dated, unconstrained=True, min_time=0),

tests/test_priors.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import pytest
2929
import utility_functions
3030

31+
from tsdate.prior import conditional_coalescent_variance
3132
from tsdate.prior import ConditionalCoalescentTimes
3233
from tsdate.prior import create_timepoints
3334
from tsdate.prior import PriorParams
@@ -73,8 +74,8 @@ def test_mixture_expect_and_var(self, logwt):
7374
mean2, var2 = priors.mixture_expect_and_var(params, weight_by_log_span=logwt)
7475
assert mean1 == pytest.approx(1 / 3) # 1/N for a cherry
7576
assert var1 == pytest.approx(1 / 9)
76-
assert mean1 == mean2
77-
assert var1 == var2
77+
assert np.isclose(mean1, mean2)
78+
assert np.isclose(var1, var2)
7879

7980
def test_mixture_expect_and_var_weight(self):
8081
priors = ConditionalCoalescentTimes(None)
@@ -100,6 +101,12 @@ def test_mixture_expect_and_var_weight(self):
100101
logwt = priors.mixture_expect_and_var(params, weight_by_log_span=True)
101102
assert np.allclose(linwt, logwt)
102103

104+
def test_fast_equals_naive(self):
105+
# test fast recursion against slow but clearly correct version
106+
true = utility_functions.conditional_coalescent_variance(100)
107+
test = conditional_coalescent_variance(100)
108+
np.testing.assert_array_almost_equal(true, test)
109+
103110

104111
class TestSpansBySamples:
105112
def test_repr(self):
@@ -131,3 +138,26 @@ def test_create_timepoints_error(self):
131138
priors.prior_distr = "bad_distr"
132139
with pytest.raises(ValueError, match="must be lognorm or gamma"):
133140
create_timepoints(priors, n_points=3)
141+
142+
143+
class TestUtilityFunctions:
144+
def test_m_prob(self):
145+
assert utility_functions.m_prob(2, 2, 3) == 1.0
146+
assert utility_functions.m_prob(2, 2, 4) == 0.25
147+
148+
def test_tau_expect(self):
149+
assert utility_functions.tau_expect(10, 10) == 1.8
150+
assert utility_functions.tau_expect(10, 100) == 0.09
151+
assert utility_functions.tau_expect(100, 100) == 1.98
152+
assert utility_functions.tau_expect(5, 10) == 0.4
153+
154+
def test_tau_squared_conditional(self):
155+
assert np.isclose(utility_functions.tau_squared_conditional(1, 10), 4.3981418)
156+
assert np.isclose(
157+
utility_functions.tau_squared_conditional(100, 100), 4.87890977e-18
158+
)
159+
160+
def test_tau_var(self):
161+
assert utility_functions.tau_var(2, 2) == 1
162+
assert np.isclose(utility_functions.tau_var(10, 20), 0.0922995960)
163+
assert np.isclose(utility_functions.tau_var(50, 50), 1.15946186)

tests/utility_functions.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import msprime
2828
import numpy as np
2929
import tskit
30+
from scipy.special import comb
3031

3132

3233
def add_grand_mrca(ts):
@@ -1066,3 +1067,51 @@ def truncate_ts_samples(ts, average_span, random_seed, min_span=5):
10661067
filter_sites=False,
10671068
keep_unary=True,
10681069
)
1070+
1071+
1072+
def m_prob(m, i, n):
1073+
"""
1074+
Corollary 2 in Wiuf and Donnelly (1999). Probability of one
1075+
ancestor to entire sample at time tau
1076+
"""
1077+
return (comb(n - m - 1, i - 2, exact=True) * comb(m, 2, exact=True)) / comb(
1078+
n, i + 1, exact=True
1079+
)
1080+
1081+
1082+
def tau_expect(i, n):
1083+
if i == n:
1084+
return 2 * (1 - (1 / n))
1085+
else:
1086+
return (i - 1) / n
1087+
1088+
1089+
def tau_squared_conditional(m, n):
1090+
"""
1091+
Gives expectation of tau squared conditional on m
1092+
Equation (10) from Wiuf and Donnelly (1999).
1093+
"""
1094+
t_sum = np.sum(1 / np.arange(m, n + 1) ** 2)
1095+
return 8 * t_sum + (8 / n) - (8 / m) - (8 / (n * m))
1096+
1097+
1098+
def tau_var(i, n):
1099+
"""
1100+
For the last coalescence (n=2), calculate the Tmrca of the whole sample
1101+
"""
1102+
if i == n:
1103+
value = np.arange(2, n + 1)
1104+
var = np.sum(1 / ((value**2) * ((value - 1) ** 2)))
1105+
return np.abs(4 * var)
1106+
elif i == 0:
1107+
return 0.0
1108+
else:
1109+
tau_square_sum = 0
1110+
for m in range(2, n - i + 2):
1111+
tau_square_sum += m_prob(m, i, n) * tau_squared_conditional(m, n)
1112+
return np.abs((tau_expect(i, n) ** 2) - (tau_square_sum))
1113+
1114+
1115+
def conditional_coalescent_variance(n_tips):
1116+
"""Variance calculation for prior, slow but clear version"""
1117+
return np.array([tau_var(i, n_tips) for i in range(n_tips + 1)])

tsdate/core.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1413,9 +1413,9 @@ def get_dates(
14131413
raise NotImplementedError("Samples must all be at time 0")
14141414
fixed_nodes = set(tree_sequence.samples())
14151415

1416-
# Default to not creating approximate priors unless ts has > 1000 samples
1416+
# Default to not creating approximate priors unless ts has > 20000 samples
14171417
approx_priors = False
1418-
if tree_sequence.num_samples > 1000:
1418+
if tree_sequence.num_samples > 20000:
14191419
approx_priors = True
14201420

14211421
if priors is None:
@@ -1602,9 +1602,9 @@ def variational_dates(
16021602
"Ignoring the oldes root is not implemented in variational dating"
16031603
)
16041604

1605-
# Default to not creating approximate priors unless ts has > 1000 samples
1605+
# Default to not creating approximate priors unless ts has > 20000 samples
16061606
approx_priors = False
1607-
if tree_sequence.num_samples > 1000:
1607+
if tree_sequence.num_samples > 20000:
16081608
approx_priors = True
16091609

16101610
if priors is None:

tsdate/demography.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,8 @@ def to_coalescent_timescale(self, time_ago):
166166
)
167167
return coalescent_time_ago
168168

169+
# TODO: multiprecision implementation -- remove at some point
170+
169171
# @staticmethod
170172
# def _Gamma(z, a=0, b=np.inf):
171173
# """

0 commit comments

Comments
 (0)