Skip to content

Refactor batch sample matching #954

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Sep 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 53 additions & 51 deletions tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,11 @@
import json
import logging
import os.path
import pickle
import random
import re
import string
import sys
import tempfile
import time
import unittest
import unittest.mock as mock

Expand Down Expand Up @@ -1366,55 +1364,6 @@ def test_equivalance(self):
assert ts1.equals(ts2, ignore_provenance=True)


class TestResume:
def count_paths(self, match_data_dir):
path_count = 0
for filename in os.listdir(match_data_dir):
with open(os.path.join(match_data_dir, filename), "rb") as f:
stored_data = pickle.load(f)
path_count += len(stored_data.results)
return path_count

def test_equivalance(self, tmpdir):
ts = msprime.simulate(5, mutation_rate=2, recombination_rate=2, random_seed=2)
sample_data = tsinfer.SampleData.from_tree_sequence(ts)
ancestor_data = tsinfer.generate_ancestors(sample_data)
ancestor_ts = tsinfer.match_ancestors(sample_data, ancestor_data)
final_ts1 = tsinfer.match_samples(
sample_data, ancestor_ts, match_data_dir=tmpdir
)
assert self.count_paths(tmpdir) == 5
final_ts2 = tsinfer.match_samples(
sample_data, ancestor_ts, match_data_dir=tmpdir
)
final_ts1.tables.assert_equals(final_ts2.tables, ignore_provenance=True)

def test_cache_used_by_timing(self, tmpdir):

ts = msprime.sim_ancestry(
100, recombination_rate=1, sequence_length=1000, random_seed=42
)
ts = msprime.sim_mutations(
ts, rate=1, random_seed=42, model=msprime.InfiniteSites()
)
sample_data = tsinfer.SampleData.from_tree_sequence(ts)
ancestor_data = tsinfer.generate_ancestors(sample_data)
ancestor_ts = tsinfer.match_ancestors(sample_data, ancestor_data)
t = time.time()
final_ts1 = tsinfer.match_samples(
sample_data, ancestor_ts, match_data_dir=tmpdir
)
time1 = time.time() - t
assert self.count_paths(tmpdir) == 200
t = time.time()
final_ts2 = tsinfer.match_samples(
sample_data, ancestor_ts, match_data_dir=tmpdir
)
time2 = time.time() - t
assert time2 < time1
final_ts1.tables.assert_equals(final_ts2.tables, ignore_provenance=True)


@pytest.mark.skipif(sys.platform == "win32", reason="No cyvcf2 on windows")
class TestBatchAncestorMatching:
def test_equivalance(self, tmp_path, tmpdir):
Expand Down Expand Up @@ -1567,6 +1516,59 @@ def test_errors(self, tmp_path, tmpdir):
tsinfer.match_ancestors_batch_groups(tmpdir / "work", 2, 3)


@pytest.mark.skipif(sys.platform == "win32", reason="No cyvcf2 on windows")
class TestBatchSampleMatching:
def test_match_samples_batch(self, tmp_path, tmpdir):
mat_sd, mask_sd, _, _ = tsutil.make_materialized_and_masked_sampledata(
tmp_path, tmpdir
)
mat_ancestors = tsinfer.generate_ancestors(mat_sd)
mask_ancestors = tsinfer.generate_ancestors(mask_sd)
mat_anc_ts = tsinfer.match_ancestors(mat_sd, mat_ancestors)
mask_anc_ts = tsinfer.match_ancestors(mask_sd, mask_ancestors)
mat_anc_ts.dump(tmpdir / "mat_anc.trees")
mask_anc_ts.dump(tmpdir / "mask_anc.trees")

mat_wd = tsinfer.match_samples_batch_init(
work_dir=tmpdir / "working_mat",
sample_data_path=mat_sd.path,
ancestral_allele="variant_ancestral_allele",
ancestor_ts_path=tmpdir / "mat_anc.trees",
min_work_per_job=1,
max_num_partitions=10,
)
for i in range(mat_wd.num_partitions):
tsinfer.match_samples_batch_partition(
work_dir=tmpdir / "working_mat",
partition_index=i,
)
mat_ts_batch = tsinfer.match_samples_batch_finalise(tmpdir / "working_mat")

mask_wd = tsinfer.match_samples_batch_init(
work_dir=tmpdir / "working_mask",
sample_data_path=mask_sd.path,
ancestral_allele="variant_ancestral_allele",
ancestor_ts_path=tmpdir / "mask_anc.trees",
min_work_per_job=1,
max_num_partitions=10,
site_mask="variant_mask_foobar",
sample_mask="samples_mask_foobar",
)
for i in range(mask_wd.num_partitions):
tsinfer.match_samples_batch_partition(
work_dir=tmpdir / "working_mask",
partition_index=i,
)
mask_ts_batch = tsinfer.match_samples_batch_finalise(tmpdir / "working_mask")

mask_ts = tsinfer.match_samples(mask_sd, mask_anc_ts)
mat_ts = tsinfer.match_samples(mat_sd, mat_anc_ts)

mat_ts.tables.assert_equals(mask_ts.tables, ignore_timestamps=True)
mask_ts.tables.assert_equals(mask_ts_batch.tables, ignore_timestamps=True)
mask_ts_batch.tables.assert_equals(mat_ts_batch.tables, ignore_timestamps=True)


class TestAncestorGeneratorsEquivalant:
"""
Tests for the ancestor generation process.
Expand Down
113 changes: 0 additions & 113 deletions tests/test_variantdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@
Tests for the data files.
"""
import json
import os
import pickle
import sys
import tempfile

Expand Down Expand Up @@ -756,114 +754,3 @@ def test_empty_alleles_not_at_end(self, tmp_path):
samples = tsinfer.VariantData(path, "variant_ancestral_allele")
with pytest.raises(ValueError, match="Empty alleles must be at the end"):
tsinfer.infer(samples)


@pytest.mark.skipif(sys.platform == "win32", reason="No cyvcf2 on windows")
class TestSgkitMatchSamplesToDisk:
@pytest.mark.parametrize("slice", [(0, 6), (0, 0), (0, 3), (12, 15)])
def test_match_samples_to_disk_write(self, slice, tmp_path, tmpdir):
ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path)
samples = tsinfer.VariantData(zarr_path, "variant_ancestral_allele")
ancestors = tsinfer.generate_ancestors(samples)
anc_ts = tsinfer.match_ancestors(samples, ancestors)
tsinfer.match_samples_slice_to_disk(
samples, anc_ts, slice, tmpdir / "samples.pkl"
)
stored = pickle.load(open(tmpdir / "samples.pkl", "rb"))
assert stored.group_id == "samples"
assert len(stored.results) == slice[1] - slice[0]
for i, (s, m) in enumerate(stored.results.items()):
assert s == slice[0] + i
assert isinstance(m, tsinfer.inference.MatchResult)

def test_match_samples_to_disk_slice_error(self, tmp_path, tmpdir):
ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path)
samples = tsinfer.VariantData(zarr_path, "variant_ancestral_allele")
ancestors = tsinfer.generate_ancestors(samples)
anc_ts = tsinfer.match_ancestors(samples, ancestors)
with pytest.raises(
ValueError, match="Samples slice must be a multiple of ploidy"
):
tsinfer.match_samples_slice_to_disk(
samples, anc_ts, (0, 1), tmpdir / "test.path"
)

def test_match_samples_to_disk_full(self, tmp_path, tmpdir):
match_data_dir = tmpdir / "match_data"
os.mkdir(match_data_dir)
ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path)
samples = tsinfer.VariantData(zarr_path, "variant_ancestral_allele")
ancestors = tsinfer.generate_ancestors(samples)
anc_ts = tsinfer.match_ancestors(samples, ancestors)
ts = tsinfer.match_samples(samples, anc_ts)
start_index = 0
while start_index < ts.num_samples:
end_index = min(start_index + 6, ts.num_samples)
tsinfer.match_samples_slice_to_disk(
samples,
anc_ts,
(start_index, end_index),
match_data_dir / f"test-{start_index}.pkl",
)
start_index = end_index
batch_ts = tsinfer.match_samples(
samples, anc_ts, match_data_dir=str(match_data_dir)
)
ts.tables.assert_equals(batch_ts.tables, ignore_provenance=True)

(match_data_dir / "test-6.pkl").copy(match_data_dir / "test-6-copy.pkl")
with pytest.raises(ValueError, match="Duplicate sample index 6"):
tsinfer.match_samples(samples, anc_ts, match_data_dir=str(match_data_dir))

os.remove(match_data_dir / "test-6.pkl")
os.remove(match_data_dir / "test-6-copy.pkl")
with pytest.raises(ValueError, match="index 6 not found"):
tsinfer.match_samples(samples, anc_ts, match_data_dir=str(match_data_dir))

def test_match_samples_to_disk_with_mask(self, tmp_path, tmpdir):
mat_sd, mask_sd, _, _ = tsutil.make_materialized_and_masked_sampledata(
tmp_path, tmpdir
)
mat_data_dir = tmpdir / "mat_data"
os.mkdir(mat_data_dir)
mask_data_dir = tmpdir / "mask_data"
os.mkdir(mask_data_dir)
mat_ancestors = tsinfer.generate_ancestors(mat_sd)
mask_ancestors = tsinfer.generate_ancestors(mask_sd)
mat_anc_ts = tsinfer.match_ancestors(mat_sd, mat_ancestors)
mask_anc_ts = tsinfer.match_ancestors(mask_sd, mask_ancestors)
start_index = 0
while start_index < mat_sd.num_samples:
end_index = min(start_index + 6, mat_sd.num_samples)
tsinfer.match_samples_slice_to_disk(
mat_sd,
mat_anc_ts,
(start_index, end_index),
mat_data_dir / f"test-mat-{start_index}.path",
)
start_index = end_index

mat_ts_disk = tsinfer.match_samples(
mat_sd, mat_anc_ts, match_data_dir=str(mat_data_dir)
)

start_index = 0
while start_index < mask_sd.num_samples:
end_index = min(start_index + 6, mask_sd.num_samples)
tsinfer.match_samples_slice_to_disk(
mask_sd,
mask_anc_ts,
(start_index, end_index),
mask_data_dir / f"test-mask-{start_index}.path",
)
start_index = end_index
mask_ts_disk = tsinfer.match_samples(
mask_sd, mask_anc_ts, match_data_dir=str(mask_data_dir)
)

mask_ts = tsinfer.match_samples(mask_sd, mask_anc_ts)
mat_ts = tsinfer.match_samples(mat_sd, mat_anc_ts)

mat_ts.tables.assert_equals(mask_ts.tables, ignore_timestamps=True)
mask_ts.tables.assert_equals(mask_ts_disk.tables, ignore_timestamps=True)
mask_ts_disk.tables.assert_equals(mat_ts_disk.tables, ignore_timestamps=True)
Loading
Loading