|
22 | 22 | """
|
23 | 23 | Test cases for divergence matrix based pairwise stats
|
24 | 24 | """
|
| 25 | +import collections |
| 26 | + |
25 | 27 | import msprime
|
26 | 28 | import numpy as np
|
27 | 29 | import pytest
|
@@ -258,6 +260,60 @@ def stats_api_divergence_matrix(ts, windows=None, samples=None):
|
258 | 260 | return out
|
259 | 261 |
|
260 | 262 |
|
def rootward_path(tree, u, v):
    """
    Yield the nodes visited when stepping from ``u`` towards the root via
    ``tree.parent``, stopping as soon as ``v`` is reached.

    ``u`` itself is yielded first (unless ``u == v``, in which case nothing
    is yielded); ``v`` is never yielded.

    NOTE(review): there is no guard for the case where ``v`` is not reached
    by following parents from ``u`` - callers are presumably expected to
    pass an ancestor of ``u``.
    """
    node = u
    while True:
        if node == v:
            return
        yield node
        node = tree.parent(node)
| 267 | + |
| 268 | + |
def site_divergence_matrix_naive(ts, windows=None, samples=None):
    """
    Naive reference computation of a site-based pairwise divergence matrix.

    For every window and every tree overlapping it, the mutations at sites
    falling inside the overlap are counted per node. For each pair of
    samples the counts are then summed along the two paths from the samples
    up to (but excluding) their MRCA, or up to their respective local roots
    (as given by ``local_root``) when the pair has no MRCA in the tree.

    :param ts: the tree sequence to analyse.
    :param windows: window breakpoints; when ``None`` a single window
        ``[0, ts.sequence_length]`` is used.
    :param samples: the sample nodes to compare; defaults to
        ``ts.samples()``.
    :return: an array of shape ``(num_windows, n, n)`` where ``n`` is the
        number of samples, or ``(n, n)`` when ``windows`` was not given.
    """
    windows_specified = windows is not None
    if windows is None:
        windows = [0, ts.sequence_length]
    if samples is None:
        samples = ts.samples()
    num_windows = len(windows) - 1

    n = len(samples)
    D = np.zeros((num_windows, n, n))
    # Index pairs (row < col) of the strict upper triangle, used to mirror
    # the accumulated values into the lower triangle after each window.
    upper_rows, upper_cols = np.triu_indices(n, k=1)
    tree = tskit.Tree(ts)
    for i in range(num_windows):
        left, right = windows[i], windows[i + 1]
        tree.seek(left)
        # Walk along the trees overlapping this window
        while tree.interval.left < right and tree.index != -1:
            span_left = max(tree.interval.left, left)
            span_right = min(tree.interval.right, right)
            # Number of mutations above each node, restricted to the sites
            # inside the part of this tree that overlaps the window.
            mutations_per_node = collections.Counter(
                mutation.node
                for site in tree.sites()
                if span_left <= site.position < span_right
                for mutation in site.mutations
            )
            for j, u in enumerate(samples):
                for k in range(j + 1, n):
                    v = samples[k]
                    ancestor = tree.mrca(u, v)
                    if ancestor == tskit.NULL:
                        # No common ancestor: stop at each sample's own root
                        stop_u = local_root(tree, u)
                        stop_v = local_root(tree, v)
                    else:
                        stop_u = stop_v = ancestor
                    du = sum(
                        mutations_per_node[x] for x in rootward_path(tree, u, stop_u)
                    )
                    dv = sum(
                        mutations_per_node[x] for x in rootward_path(tree, v, stop_v)
                    )
                    # NOTE: we're just accumulating the raw mutation counts, not
                    # multiplying by span
                    D[i, j, k] += du + dv
            tree.next()
        # Copy the upper triangle into the lower one to make D[i] symmetric
        D[i, upper_cols, upper_rows] = D[i, upper_rows, upper_cols]
    if not windows_specified:
        D = D[0]
    return D
| 315 | + |
| 316 | + |
261 | 317 | # NOTE: the internal_checks argument is left over from an older version that
|
262 | 318 | # used a more complex algorithm - not removing for now so we can reuse the
|
263 | 319 | # tests that were designed to exercise the code paths of that algorithm.
|
@@ -696,6 +752,15 @@ def test_threads_windows(self, ts):
|
696 | 752 | self.check(ts, num_threads=5, windows=windows)
|
697 | 753 |
|
698 | 754 |
|
class TestSiteDivergence:
    """Sanity checks for ``site_divergence_matrix_naive``."""

    def test_simulation_example(self):
        """
        Run the naive site divergence on a small simulated example and
        check the structural invariants of the result.
        """
        ts = msprime.sim_ancestry(4, sequence_length=100, random_seed=2)
        # Seed the mutation step too, so the example is fully reproducible.
        ts = msprime.sim_mutations(ts, rate=1, random_seed=2)
        assert ts.num_mutations > 0
        D1 = site_divergence_matrix_naive(ts)
        n = ts.num_samples
        # No windows were passed, so a single (n, n) matrix is returned.
        assert D1.shape == (n, n)
        # Only off-diagonal pairs are accumulated, then mirrored.
        assert np.all(np.diag(D1) == 0)
        np.testing.assert_array_equal(D1, D1.T)
        # Entries are sums of mutation counts along two disjoint paths, so
        # they are non-negative and cannot exceed the total mutation count.
        assert np.all(D1 >= 0)
        assert np.all(D1 <= ts.num_mutations)
        assert np.any(D1 > 0)
| 762 | + |
| 763 | + |
699 | 764 | class TestThreadsNoWindows:
|
700 | 765 | def check(self, ts, num_threads, samples=None):
|
701 | 766 | D1 = ts.divergence_matrix(num_threads=0, samples=samples)
|
|
0 commit comments