Skip to content

Batch ancestor matching #917

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 0 additions & 65 deletions tests/test_dask.py

This file was deleted.

196 changes: 153 additions & 43 deletions tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import json
import logging
import os.path
import pickle
import random
import re
import string
Expand All @@ -33,7 +34,6 @@
import unittest
import unittest.mock as mock

import lmdb
import msprime
import numpy as np
import pytest
Expand Down Expand Up @@ -1366,43 +1366,31 @@ def test_equivalance(self):
assert ts1.equals(ts2, ignore_provenance=True)


@pytest.mark.skipif(IS_WINDOWS, reason="Not enough disk space as no sparse files")
class TestResume:
def count_keys(self, lmdb_file):
with lmdb.open(
lmdb_file, subdir=False, map_size=100 * 1024 * 1024 * 1024
) as lmdb_file:
with lmdb_file.begin() as txn:
# Count the number of keys
n_keys = 0
for _ in txn.cursor():
n_keys += 1
return n_keys
def count_paths(self, match_data_dir):
path_count = 0
for filename in os.listdir(match_data_dir):
with open(os.path.join(match_data_dir, filename), "rb") as f:
stored_data = pickle.load(f)
path_count += len(stored_data.results)
return path_count

def test_equivalance(self, tmpdir):
lmdb_file = str(tmpdir / "LMDB")
ts = msprime.simulate(5, mutation_rate=2, recombination_rate=2, random_seed=2)
sample_data = tsinfer.SampleData.from_tree_sequence(ts)
ancestor_data = tsinfer.generate_ancestors(sample_data)
ancestor_ts1 = tsinfer.match_ancestors(
sample_data, ancestor_data, resume_lmdb_file=lmdb_file
)
assert self.count_keys(lmdb_file) == 4
ancestor_ts2 = tsinfer.match_ancestors(
sample_data, ancestor_data, resume_lmdb_file=lmdb_file
)
ancestor_ts1.tables.assert_equals(ancestor_ts2.tables, ignore_provenance=True)
ancestor_ts = tsinfer.match_ancestors(sample_data, ancestor_data)
final_ts1 = tsinfer.match_samples(
sample_data, ancestor_ts1, resume_lmdb_file=lmdb_file
sample_data, ancestor_ts, match_data_dir=tmpdir
)
assert self.count_keys(lmdb_file) == 5
assert self.count_paths(tmpdir) == 5
final_ts2 = tsinfer.match_samples(
sample_data, ancestor_ts1, resume_lmdb_file=lmdb_file
sample_data, ancestor_ts, match_data_dir=tmpdir
)
final_ts1.tables.assert_equals(final_ts2.tables, ignore_provenance=True)

def test_cache_used_by_timing(self, tmpdir):
lmdb_file = str(tmpdir / "LMDB")

ts = msprime.sim_ancestry(
100, recombination_rate=1, sequence_length=1000, random_seed=42
)
Expand All @@ -1411,35 +1399,157 @@ def test_cache_used_by_timing(self, tmpdir):
)
sample_data = tsinfer.SampleData.from_tree_sequence(ts)
ancestor_data = tsinfer.generate_ancestors(sample_data)
t = time.time()
ancestor_ts1 = tsinfer.match_ancestors(
sample_data, ancestor_data, resume_lmdb_file=lmdb_file
)
time1 = time.time() - t
assert self.count_keys(lmdb_file) >= 103
t = time.time()
ancestor_ts2 = tsinfer.match_ancestors(
sample_data, ancestor_data, resume_lmdb_file=lmdb_file
)
ancestor_ts1.tables.assert_equals(ancestor_ts2.tables, ignore_provenance=True)
time2 = time.time() - t
assert time2 < time1 / 2

ancestor_ts = tsinfer.match_ancestors(sample_data, ancestor_data)
t = time.time()
final_ts1 = tsinfer.match_samples(
sample_data, ancestor_ts1, resume_lmdb_file=lmdb_file
sample_data, ancestor_ts, match_data_dir=tmpdir
)
time1 = time.time() - t
assert self.count_keys(lmdb_file) == 104
assert self.count_paths(tmpdir) == 200
t = time.time()
final_ts2 = tsinfer.match_samples(
sample_data, ancestor_ts1, resume_lmdb_file=lmdb_file
sample_data, ancestor_ts, match_data_dir=tmpdir
)
time2 = time.time() - t
assert time2 < time1 / 1.25
final_ts1.tables.assert_equals(final_ts2.tables, ignore_provenance=True)


@pytest.mark.skipif(sys.platform == "win32", reason="No cyvcf2 on windows")
class TestBatchAncestorMatching:
    """Tests that the batched ancestor-matching API produces the same tree
    sequence as a single ``match_ancestors`` call, and that the batch
    entry points validate their arguments."""

    def test_equivalance(self, tmp_path, tmpdir):
        # Run each ancestor group on its own, two partitions per group.
        ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path)
        samples = tsinfer.SgkitSampleData(zarr_path)
        ancestors = tsinfer.generate_ancestors(
            samples, path=str(tmpdir / "ancestors.zarr")
        )
        work_dir = tmpdir / "work"
        metadata = tsinfer.match_ancestors_batch_init(
            work_dir, zarr_path, tmpdir / "ancestors.zarr", 1000
        )
        num_groups = len(metadata["ancestor_grouping"])
        for idx in range(num_groups):
            tsinfer.match_ancestors_batch_groups(work_dir, idx, idx + 1, 2)
        batched_ts = tsinfer.match_ancestors_batch_finalise(work_dir)
        direct_ts = tsinfer.match_ancestors(samples, ancestors)
        batched_ts.tables.assert_equals(direct_ts.tables, ignore_provenance=True)

    def test_equivalance_many_at_once(self, tmp_path, tmpdir):
        # Process the groups in two large chunks rather than one at a time.
        ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path)
        samples = tsinfer.SgkitSampleData(zarr_path)
        ancestors = tsinfer.generate_ancestors(
            samples, path=str(tmpdir / "ancestors.zarr")
        )
        work_dir = tmpdir / "work"
        metadata = tsinfer.match_ancestors_batch_init(
            work_dir, zarr_path, tmpdir / "ancestors.zarr", 1000
        )
        num_groups = len(metadata["ancestor_grouping"])
        midpoint = num_groups // 2
        tsinfer.match_ancestors_batch_groups(work_dir, 0, midpoint, 2)
        tsinfer.match_ancestors_batch_groups(work_dir, midpoint, num_groups, 2)
        # TODO Check which ones written to disk
        batched_ts = tsinfer.match_ancestors_batch_finalise(work_dir)
        direct_ts = tsinfer.match_ancestors(samples, ancestors)
        batched_ts.tables.assert_equals(direct_ts.tables, ignore_provenance=True)

    def test_equivalance_with_partitions(self, tmp_path, tmpdir):
        # Exercise the per-partition entry points for partitioned groups.
        ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path)
        samples = tsinfer.SgkitSampleData(zarr_path)
        ancestors = tsinfer.generate_ancestors(
            samples, path=str(tmpdir / "ancestors.zarr")
        )
        work_dir = tmpdir / "work"
        metadata = tsinfer.match_ancestors_batch_init(
            work_dir, zarr_path, tmpdir / "ancestors.zarr", 1000
        )
        for idx, group in enumerate(metadata["ancestor_grouping"]):
            if group["partitions"] is None:
                tsinfer.match_ancestors_batch_groups(work_dir, idx, idx + 1)
            else:
                for part in range(len(group["partitions"])):
                    tsinfer.match_ancestors_batch_group_partition(
                        work_dir, idx, part
                    )
                ts = tsinfer.match_ancestors_batch_group_finalise(work_dir, idx)
        batched_ts = tsinfer.match_ancestors_batch_finalise(work_dir)
        direct_ts = tsinfer.match_ancestors(samples, ancestors)
        batched_ts.tables.assert_equals(direct_ts.tables, ignore_provenance=True)

    def test_max_partitions(self, tmp_path, tmpdir):
        # max_num_partitions should cap the partition count in every group.
        ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path)
        samples = tsinfer.SgkitSampleData(zarr_path)
        ancestors = tsinfer.generate_ancestors(
            samples, path=str(tmpdir / "ancestors.zarr")
        )
        work_dir = tmpdir / "work"
        metadata = tsinfer.match_ancestors_batch_init(
            work_dir,
            zarr_path,
            tmpdir / "ancestors.zarr",
            10000,
            max_num_partitions=2,
        )
        for idx, group in enumerate(metadata["ancestor_grouping"]):
            if group["partitions"] is None:
                tsinfer.match_ancestors_batch_groups(work_dir, idx, idx + 1)
            else:
                assert len(group["partitions"]) <= 2
                for part in range(len(group["partitions"])):
                    tsinfer.match_ancestors_batch_group_partition(
                        work_dir, idx, part
                    )
                ts = tsinfer.match_ancestors_batch_group_finalise(work_dir, idx)
        batched_ts = tsinfer.match_ancestors_batch_finalise(work_dir)
        direct_ts = tsinfer.match_ancestors(samples, ancestors)
        batched_ts.tables.assert_equals(direct_ts.tables, ignore_provenance=True)

    def test_errors(self, tmp_path, tmpdir):
        # Out-of-range and inconsistent-state calls must raise ValueError.
        ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path)
        samples = tsinfer.SgkitSampleData(zarr_path)
        tsinfer.generate_ancestors(samples, path=str(tmpdir / "ancestors.zarr"))
        work_dir = tmpdir / "work"
        metadata = tsinfer.match_ancestors_batch_init(
            work_dir, zarr_path, tmpdir / "ancestors.zarr", 1000
        )
        with pytest.raises(ValueError, match="out of range"):
            tsinfer.match_ancestors_batch_groups(work_dir, -1, 1)
        with pytest.raises(ValueError, match="out of range"):
            tsinfer.match_ancestors_batch_groups(work_dir, 0, -1)
        with pytest.raises(ValueError, match="must be greater"):
            tsinfer.match_ancestors_batch_groups(work_dir, 5, 4)

        with pytest.raises(ValueError, match="has no partitions"):
            tsinfer.match_ancestors_batch_group_partition(work_dir, 0, 1)
        last_group = len(metadata["ancestor_grouping"]) - 1
        with pytest.raises(ValueError, match="out of range"):
            tsinfer.match_ancestors_batch_group_partition(
                work_dir, last_group, 1000
            )

        # Match a single group so an intermediate ts is written to disk
        tsinfer.match_ancestors_batch_groups(work_dir, 0, 2)
        assert (work_dir / "ancestors_1.trees").exists()

        # Corrupt the stored ts by changing its sequence length
        stored = tskit.load(str(work_dir / "ancestors_1.trees"))
        tables = stored.dump_tables()
        tables.sequence_length += 1
        tables.tree_sequence().dump(str(work_dir / "ancestors_1.trees"))
        with pytest.raises(ValueError, match="sequence length is different"):
            tsinfer.match_ancestors_batch_groups(work_dir, 2, 3)


class TestAncestorGeneratorsEquivalant:
"""
Tests for the ancestor generation process.
Expand Down Expand Up @@ -2490,7 +2600,7 @@ def test_treeseq_builder_print_state(self):
matcher_container = tsinfer.AncestorMatcher(
sample_data, ancestor_data, engine=tsinfer.PY_ENGINE
)
matcher_container.match_ancestors()
matcher_container.match_ancestors(matcher_container.group_by_linesweep())
with mock.patch("sys.stdout", new=io.StringIO()) as mockOutput:
matcher_container.tree_sequence_builder.print_state()
# Simply check some text is output
Expand Down
Loading
Loading