Skip to content

Commit d645c63

Browse files
committed
Change inference tests to check missing data imputation
1 parent 5974528 commit d645c63

File tree

1 file changed

+34
-98
lines changed

1 file changed

+34
-98
lines changed

tests/test_inference.py

Lines changed: 34 additions & 98 deletions
Original file line number | Diff line number | Diff line change
@@ -34,7 +34,6 @@
3434

3535
import tsinfer
3636
import tsinfer.eval_util as eval_util
37-
import tsutil
3837

3938

4039
def get_random_data_example(num_samples, num_sites, seed=42, num_states=2):
@@ -2155,122 +2154,59 @@ def verify_example(self, full_subset, samples, ancestors, path_compression):
21552154
ancestors_ts = augmented_ancestors
21562155

21572156

2158-
class TestMissingSampleDataInference(unittest.TestCase):
2157+
class TestMissingDataImputed(unittest.TestCase):
21592158
"""
2160-
Test that we can infer sites with tskit.MISSING_DATA, using both the PY and C engines
2159+
Test that sites with tskit.MISSING_DATA are imputed, using both the PY and C engines
21612160
"""
2162-
def test_missing_haplotypes(self):
2161+
def test_missing_site(self):
21632162
u = tskit.MISSING_DATA
21642163
sites_by_samples = np.array([
2165-
[u, u, u, u],
2166-
[u, 1, 0, u],
2167-
[u, 0, 1, u],
2168-
[u, 1, 1, u]
2164+
[u, u, u, u, u], # Site 0
2165+
[1, 1, 1, 0, 1], # Site 1
2166+
[1, 0, 1, 1, 0], # Site 2
2167+
[0, 0, 0, 1, 0], # Site 3
21692168
], dtype=np.int8)
2169+
expected = sites_by_samples.copy()
2170+
expected[0, :] = [0, 0, 0, 0, 0]
21702171
with tsinfer.SampleData() as sample_data:
2171-
for col in range(sites_by_samples.shape[1]):
2172-
sample_data.add_site(col, sites_by_samples[:, col])
2172+
for row in range(sites_by_samples.shape[0]):
2173+
sample_data.add_site(row, sites_by_samples[row, :])
21732174
for e in [tsinfer.PY_ENGINE, tsinfer.C_ENGINE]:
21742175
ts = tsinfer.infer(sample_data, engine=e)
2175-
self.assertTrue(np.all(sites_by_samples == ts.genotype_matrix().T))
2176+
self.assertEquals(ts.num_trees, 2)
2177+
self.assertTrue(np.all(expected == ts.genotype_matrix()))
21762178

2177-
def test_samples_missing_inference_sites(self):
2179+
def test_missing_haplotype(self):
21782180
u = tskit.MISSING_DATA
21792181
sites_by_samples = np.array([
2180-
[1, 0, 0, u],
2181-
[1, 0, 0, u],
2182-
[0, 1, 1, 1],
2183-
[u, u, u, 1]], dtype=np.int8)
2182+
[u, 1, 1, 1, 0], # Site 0
2183+
[u, 1, 1, 0, 0], # Site 1
2184+
[u, 0, 0, 1, 0], # Site 2
2185+
[u, 0, 1, 1, 0], # Site 3
2186+
], dtype=np.int8)
2187+
expected = sites_by_samples.copy()
2188+
expected[:, 0] = [0, 0, 0, 0]
21842189
with tsinfer.SampleData() as sample_data:
2185-
for col in range(sites_by_samples.shape[1]):
2186-
sample_data.add_site(col, sites_by_samples[:, col])
2190+
for row in range(sites_by_samples.shape[0]):
2191+
sample_data.add_site(row, sites_by_samples[row, :])
21872192
for e in [tsinfer.PY_ENGINE, tsinfer.C_ENGINE]:
21882193
ts = tsinfer.infer(sample_data, engine=e)
2189-
self.assertTrue(np.all(sites_by_samples == ts.genotype_matrix().T))
2194+
self.assertEquals(ts.num_trees, 2)
2195+
self.assertTrue(np.all(expected == ts.genotype_matrix()))
21902196

2191-
def test_samples_imputed_noninference_sites(self):
2197+
def test_missing_inference_sites(self):
21922198
u = tskit.MISSING_DATA
21932199
sites_by_samples = np.array([
2194-
[0, u, 1, 1], # Sample A
2195-
[0, u, u, 1], # Sample B
2196-
[0, u, 1, 1], # Sample C
2197-
[1, u, 0, 0] # Sample D
2200+
[u, 1, 1, 1, 0], # Site 0
2201+
[1, 1, u, 0, 0], # Site 1
21982202
], dtype=np.int8)
2199-
infer_site = [None, None, False, None]
2200-
# Sites all compatible with a single tree: ((A,B,C),D);
2201-
# Site 2 (all missing) should be imputed to all 0
2202-
# Site 3 should be imputed to have 1 at the missing site
2203+
expected = sites_by_samples.copy()
2204+
expected[:, 0] = [1, 1]
2205+
expected[:, 2] = [1, 0]
22032206
with tsinfer.SampleData() as sample_data:
2204-
for col in range(sites_by_samples.shape[1]):
2205-
sample_data.add_site(
2206-
col, sites_by_samples[:, col], inference=infer_site[col])
2207+
for row in range(sites_by_samples.shape[0]):
2208+
sample_data.add_site(row, sites_by_samples[row, :])
22072209
for e in [tsinfer.PY_ENGINE, tsinfer.C_ENGINE]:
22082210
ts = tsinfer.infer(sample_data, engine=e)
22092211
self.assertEquals(ts.num_trees, 1)
2210-
self.assertTrue(np.all(ts.genotype_matrix().T[:, 0] ==
2211-
sites_by_samples[:, 0]))
2212-
self.assertTrue(np.all(ts.genotype_matrix().T[:, 1] ==
2213-
np.array([0, 0, 0, 0])))
2214-
self.assertTrue(np.all(ts.genotype_matrix().T[:, 2] ==
2215-
np.array([1, 1, 1, 0])))
2216-
self.assertTrue(np.all(ts.genotype_matrix().T[:, 3] ==
2217-
sites_by_samples[:, 3]))
2218-
2219-
def test_small_truncated_fragments(self):
2220-
u = tskit.MISSING_DATA
2221-
sites_by_samples = np.array([
2222-
[u, u, u, 1, 1, 0, 1, 1, 1, u],
2223-
[u, u, u, 1, 0, 0, 1, 1, 0, u],
2224-
[u, u, u, 1, 0, 1, 1, 0, 1, u],
2225-
[u, 0, 0, 1, 0, 1, 1, u, u, u],
2226-
[u, 0, 1, 1, 0, 0, 1, u, u, u],
2227-
[u, 1, 1, 0, 0, 0, 0, u, u, u]
2228-
], dtype=np.int8)
2229-
with tsinfer.SampleData() as sample_data:
2230-
for col in range(sites_by_samples.shape[1]):
2231-
sample_data.add_site(col, sites_by_samples[:, col])
2232-
for e in [tsinfer.PY_ENGINE, tsinfer.C_ENGINE]:
2233-
ancestors = tsinfer.generate_ancestors(sample_data, engine=e)
2234-
ancestors_ts = tsinfer.match_ancestors(
2235-
sample_data, ancestors, engine=e, extended_checks=True)
2236-
ts = tsinfer.match_samples(
2237-
sample_data, ancestors_ts, engine=e, extended_checks=True)
2238-
self.assertTrue(1.0 in ts.breakpoints(True)) # End of lft unknown region
2239-
self.assertTrue(3.0 in ts.breakpoints(True)) # End of 1st unknown batch
2240-
self.assertTrue(7.0 in ts.breakpoints(True)) # Start of 2nd unknown batch
2241-
self.assertTrue(9.0 in ts.breakpoints(True)) # Start of rgt unknown region
2242-
for tree in ts.trees():
2243-
for s in ts.samples():
2244-
if tree.interval[1] <= 1:
2245-
self.assertTrue(tree.parent(s) == tskit.NULL)
2246-
elif tree.interval[1] <= 3:
2247-
if s in [0, 1, 2]: # Missing data at pos <=3 for these samples
2248-
self.assertTrue(tree.parent(s) == tskit.NULL)
2249-
else:
2250-
self.assertTrue(tree.parent(s) != tskit.NULL)
2251-
elif tree.interval[0] >= 9:
2252-
self.assertTrue(tree.parent(s) == tskit.NULL)
2253-
elif tree.interval[0] >= 7:
2254-
if s in [3, 4, 5]: # Missing data at pos >=7 for these samples
2255-
self.assertTrue(tree.parent(s) == tskit.NULL)
2256-
else:
2257-
self.assertTrue(tree.parent(s) != tskit.NULL)
2258-
2259-
self.assertTrue(np.all(sites_by_samples == ts.genotype_matrix().T))
2260-
2261-
def test_large_truncated_fragments(self):
2262-
"""
2263-
A bit like fragments produced from a sequencer
2264-
"""
2265-
ts = msprime.simulate(
2266-
10, Ne=1e2, length=400, recombination_rate=1e-4, mutation_rate=2e-4,
2267-
random_seed=1)
2268-
truncated_ts = tsutil.truncate_ts_samples(ts, average_span=200, random_seed=123)
2269-
sd = tsinfer.SampleData.from_tree_sequence(truncated_ts, use_times=False)
2270-
# Cannot use the normal `simplify` as this removes parts of the TS where only
2271-
# one sample is connected to the root (& the other samples have missing data)
2272-
ts_inferred = tsinfer.infer(sd, engine="P")
2273-
non_missing = truncated_ts.genotype_matrix() != tskit.MISSING_DATA
2274-
self.assertTrue(
2275-
np.all(ts_inferred.genotype_matrix()[non_missing] ==
2276-
truncated_ts.genotype_matrix()[non_missing]))
2212+
self.assertTrue(np.all(expected == ts.genotype_matrix()))

0 commit comments

Comments
 (0)