Skip to content

Commit cf5269f

Browse files
Rough first-pass at site-divmat
1 parent eda49cf commit cf5269f

File tree

1 file changed

+65
-0
lines changed

1 file changed

+65
-0
lines changed

python/tests/test_divmat.py

+65
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
"""
2323
Test cases for divergence matrix based pairwise stats
2424
"""
25+
import collections
26+
2527
import msprime
2628
import numpy as np
2729
import pytest
@@ -258,6 +260,60 @@ def stats_api_divergence_matrix(ts, windows=None, samples=None):
258260
return out
259261

260262

263+
def rootward_path(tree, u, v):
    """
    Yield the nodes visited when stepping from ``u`` through successive
    ``tree.parent`` calls, stopping just before ``v`` is reached.

    ``u`` itself is yielded first (unless ``u == v``, in which case nothing
    is yielded); ``v`` is never yielded.
    """
    node = u
    while True:
        if node == v:
            # Reached the stopping node: it is excluded from the path.
            return
        yield node
        node = tree.parent(node)
267+
268+
269+
def site_divergence_matrix_naive(ts, windows=None, samples=None):
    """
    Naive pairwise site-divergence matrix, computed tree by tree.

    For every window and every pair of samples, count the mutations that
    sit on the nodes between each sample and the pair's MRCA (the MRCA
    node itself is excluded). When the pair has no MRCA in a tree, each
    sample's path instead stops at the node returned by ``local_root``
    (helper defined elsewhere in this file).

    :param ts: The tree sequence to analyse.
    :param windows: Window breakpoints; defaults to a single window
        ``[0, ts.sequence_length]``.
    :param samples: The sample nodes to compare; defaults to
        ``ts.samples()``.
    :return: An array of shape ``(num_windows, n, n)`` if ``windows`` was
        given, otherwise the single ``(n, n)`` matrix.
    """
    windows_specified = windows is not None
    if windows is None:
        windows = [0, ts.sequence_length]
    if samples is None:
        samples = ts.samples()
    num_windows = len(windows) - 1
    n = len(samples)

    D = np.zeros((num_windows, n, n))
    # Index pairs of the strict upper triangle, used to mirror each window.
    upper_rows, upper_cols = np.triu_indices(n, k=1)
    tree = tskit.Tree(ts)
    for i in range(num_windows):
        left, right = windows[i], windows[i + 1]
        tree.seek(left)
        # Walk over every tree that overlaps this window.
        while tree.interval.left < right and tree.index != -1:
            span_left = max(tree.interval.left, left)
            span_right = min(tree.interval.right, right)

            # Tally the mutations above each node, restricted to sites that
            # fall inside the part of this tree covered by the window.
            mutations_per_node = collections.Counter()
            for site in tree.sites():
                if not (span_left <= site.position < span_right):
                    continue
                for mutation in site.mutations:
                    mutations_per_node[mutation.node] += 1

            for j, u in enumerate(samples):
                for k in range(j + 1, n):
                    v = samples[k]
                    w = tree.mrca(u, v)
                    if w == tskit.NULL:
                        # No common ancestor in this tree: each sample's
                        # path stops at its own local root instead.
                        wu = local_root(tree, u)
                        wv = local_root(tree, v)
                    else:
                        wu = wv = w
                    # NOTE: we're just accumulating the raw mutation counts,
                    # not multiplying by span
                    pair_total = 0
                    for start, stop in ((u, wu), (v, wv)):
                        pair_total += sum(
                            mutations_per_node[x]
                            for x in rootward_path(tree, start, stop)
                        )
                    D[i, j, k] += pair_total
            tree.next()
        # Mirror the upper triangle into the lower one.
        D[i, upper_cols, upper_rows] = D[i, upper_rows, upper_cols]
    return D if windows_specified else D[0]
315+
316+
261317
# NOTE: the internal_checks argument is left over from an older version that
262318
# used a more complex algorithm - not removing for now so we can reuse the
263319
# tests that were designed to exercise the code paths of that algorithm.
@@ -696,6 +752,15 @@ def test_threads_windows(self, ts):
696752
self.check(ts, num_threads=5, windows=windows)
697753

698754

755+
class TestSiteDivergence:
    """
    Sanity checks for the naive site-mode divergence matrix.
    """

    def test_simulation_example(self):
        """
        Run the naive implementation on a small simulated example and check
        the structural properties that its construction guarantees.
        """
        ts = msprime.sim_ancestry(4, sequence_length=100, random_seed=2)
        ts = msprime.sim_mutations(ts, rate=1)
        # The check is vacuous unless the simulation produced mutations.
        assert ts.num_mutations > 0
        D1 = site_divergence_matrix_naive(ts)
        # No windows were passed, so a single (n, n) matrix comes back.
        assert D1.shape == (ts.num_samples, ts.num_samples)
        # The lower triangle is copied from the upper one.
        assert np.array_equal(D1, D1.T)
        # Only off-diagonal entries are ever accumulated into.
        assert np.all(np.diag(D1) == 0)
        # Entries are sums of mutation counts, so never negative.
        assert np.all(D1 >= 0)
762+
763+
699764
class TestThreadsNoWindows:
700765
def check(self, ts, num_threads, samples=None):
701766
D1 = ts.divergence_matrix(num_threads=0, samples=samples)

0 commit comments

Comments
 (0)