|
34 | 34 |
|
35 | 35 | import tsinfer |
36 | 36 | import tsinfer.eval_util as eval_util |
37 | | -import tsutil |
38 | 37 |
|
39 | 38 |
|
40 | 39 | def get_random_data_example(num_samples, num_sites, seed=42, num_states=2): |
@@ -2155,122 +2154,59 @@ def verify_example(self, full_subset, samples, ancestors, path_compression): |
2155 | 2154 | ancestors_ts = augmented_ancestors |
2156 | 2155 |
|
2157 | 2156 |
|
2158 | | -class TestMissingSampleDataInference(unittest.TestCase): |
| 2157 | +class TestMissingDataImputed(unittest.TestCase): |
2159 | 2158 | """ |
2160 | | - Test that we can infer sites with tskit.MISSING_DATA, using both the PY and C engines |
| 2159 | + Test that sites with tskit.MISSING_DATA are imputed, using both the PY and C engines |
2161 | 2160 | """ |
2162 | | - def test_missing_haplotypes(self): |
| 2161 | + def test_missing_site(self): |
2163 | 2162 | u = tskit.MISSING_DATA |
2164 | 2163 | sites_by_samples = np.array([ |
2165 | | - [u, u, u, u], |
2166 | | - [u, 1, 0, u], |
2167 | | - [u, 0, 1, u], |
2168 | | - [u, 1, 1, u] |
| 2164 | + [u, u, u, u, u], # Site 0 |
| 2165 | + [1, 1, 1, 0, 1], # Site 1 |
| 2166 | + [1, 0, 1, 1, 0], # Site 2 |
| 2167 | + [0, 0, 0, 1, 0], # Site 3 |
2169 | 2168 | ], dtype=np.int8) |
| 2169 | + expected = sites_by_samples.copy() |
| 2170 | + expected[0, :] = [0, 0, 0, 0, 0] |
2170 | 2171 | with tsinfer.SampleData() as sample_data: |
2171 | | - for col in range(sites_by_samples.shape[1]): |
2172 | | - sample_data.add_site(col, sites_by_samples[:, col]) |
| 2172 | + for row in range(sites_by_samples.shape[0]): |
| 2173 | + sample_data.add_site(row, sites_by_samples[row, :]) |
2173 | 2174 | for e in [tsinfer.PY_ENGINE, tsinfer.C_ENGINE]: |
2174 | 2175 | ts = tsinfer.infer(sample_data, engine=e) |
2175 | | - self.assertTrue(np.all(sites_by_samples == ts.genotype_matrix().T)) |
| 2176 | + self.assertEquals(ts.num_trees, 2) |
| 2177 | + self.assertTrue(np.all(expected == ts.genotype_matrix())) |
2176 | 2178 |
|
2177 | | - def test_samples_missing_inference_sites(self): |
| 2179 | + def test_missing_haplotype(self): |
2178 | 2180 | u = tskit.MISSING_DATA |
2179 | 2181 | sites_by_samples = np.array([ |
2180 | | - [1, 0, 0, u], |
2181 | | - [1, 0, 0, u], |
2182 | | - [0, 1, 1, 1], |
2183 | | - [u, u, u, 1]], dtype=np.int8) |
| 2182 | + [u, 1, 1, 1, 0], # Site 0 |
| 2183 | + [u, 1, 1, 0, 0], # Site 1 |
| 2184 | + [u, 0, 0, 1, 0], # Site 2 |
| 2185 | + [u, 0, 1, 1, 0], # Site 3 |
| 2186 | + ], dtype=np.int8) |
| 2187 | + expected = sites_by_samples.copy() |
| 2188 | + expected[:, 0] = [0, 0, 0, 0] |
2184 | 2189 | with tsinfer.SampleData() as sample_data: |
2185 | | - for col in range(sites_by_samples.shape[1]): |
2186 | | - sample_data.add_site(col, sites_by_samples[:, col]) |
| 2190 | + for row in range(sites_by_samples.shape[0]): |
| 2191 | + sample_data.add_site(row, sites_by_samples[row, :]) |
2187 | 2192 | for e in [tsinfer.PY_ENGINE, tsinfer.C_ENGINE]: |
2188 | 2193 | ts = tsinfer.infer(sample_data, engine=e) |
2189 | | - self.assertTrue(np.all(sites_by_samples == ts.genotype_matrix().T)) |
| 2194 | + self.assertEquals(ts.num_trees, 2) |
| 2195 | + self.assertTrue(np.all(expected == ts.genotype_matrix())) |
2190 | 2196 |
|
2191 | | - def test_samples_imputed_noninference_sites(self): |
| 2197 | + def test_missing_inference_sites(self): |
2192 | 2198 | u = tskit.MISSING_DATA |
2193 | 2199 | sites_by_samples = np.array([ |
2194 | | - [0, u, 1, 1], # Sample A |
2195 | | - [0, u, u, 1], # Sample B |
2196 | | - [0, u, 1, 1], # Sample C |
2197 | | - [1, u, 0, 0] # Sample D |
| 2200 | + [u, 1, 1, 1, 0], # Site 0 |
| 2201 | + [1, 1, u, 0, 0], # Site 1 |
2198 | 2202 | ], dtype=np.int8) |
2199 | | - infer_site = [None, None, False, None] |
2200 | | - # Sites all compatible with a single tree: ((A,B,C),D); |
2201 | | - # Site 2 (all missing) should be imputed to all 0 |
2202 | | - # Site 3 should be imputed to have 1 at the missing site |
| 2203 | + expected = sites_by_samples.copy() |
| 2204 | + expected[:, 0] = [1, 1] |
| 2205 | + expected[:, 2] = [1, 0] |
2203 | 2206 | with tsinfer.SampleData() as sample_data: |
2204 | | - for col in range(sites_by_samples.shape[1]): |
2205 | | - sample_data.add_site( |
2206 | | - col, sites_by_samples[:, col], inference=infer_site[col]) |
| 2207 | + for row in range(sites_by_samples.shape[0]): |
| 2208 | + sample_data.add_site(row, sites_by_samples[row, :]) |
2207 | 2209 | for e in [tsinfer.PY_ENGINE, tsinfer.C_ENGINE]: |
2208 | 2210 | ts = tsinfer.infer(sample_data, engine=e) |
2209 | 2211 | self.assertEquals(ts.num_trees, 1) |
2210 | | - self.assertTrue(np.all(ts.genotype_matrix().T[:, 0] == |
2211 | | - sites_by_samples[:, 0])) |
2212 | | - self.assertTrue(np.all(ts.genotype_matrix().T[:, 1] == |
2213 | | - np.array([0, 0, 0, 0]))) |
2214 | | - self.assertTrue(np.all(ts.genotype_matrix().T[:, 2] == |
2215 | | - np.array([1, 1, 1, 0]))) |
2216 | | - self.assertTrue(np.all(ts.genotype_matrix().T[:, 3] == |
2217 | | - sites_by_samples[:, 3])) |
2218 | | - |
2219 | | - def test_small_truncated_fragments(self): |
2220 | | - u = tskit.MISSING_DATA |
2221 | | - sites_by_samples = np.array([ |
2222 | | - [u, u, u, 1, 1, 0, 1, 1, 1, u], |
2223 | | - [u, u, u, 1, 0, 0, 1, 1, 0, u], |
2224 | | - [u, u, u, 1, 0, 1, 1, 0, 1, u], |
2225 | | - [u, 0, 0, 1, 0, 1, 1, u, u, u], |
2226 | | - [u, 0, 1, 1, 0, 0, 1, u, u, u], |
2227 | | - [u, 1, 1, 0, 0, 0, 0, u, u, u] |
2228 | | - ], dtype=np.int8) |
2229 | | - with tsinfer.SampleData() as sample_data: |
2230 | | - for col in range(sites_by_samples.shape[1]): |
2231 | | - sample_data.add_site(col, sites_by_samples[:, col]) |
2232 | | - for e in [tsinfer.PY_ENGINE, tsinfer.C_ENGINE]: |
2233 | | - ancestors = tsinfer.generate_ancestors(sample_data, engine=e) |
2234 | | - ancestors_ts = tsinfer.match_ancestors( |
2235 | | - sample_data, ancestors, engine=e, extended_checks=True) |
2236 | | - ts = tsinfer.match_samples( |
2237 | | - sample_data, ancestors_ts, engine=e, extended_checks=True) |
2238 | | - self.assertTrue(1.0 in ts.breakpoints(True)) # End of lft unknown region |
2239 | | - self.assertTrue(3.0 in ts.breakpoints(True)) # End of 1st unknown batch |
2240 | | - self.assertTrue(7.0 in ts.breakpoints(True)) # Start of 2nd unknown batch |
2241 | | - self.assertTrue(9.0 in ts.breakpoints(True)) # Start of rgt unknown region |
2242 | | - for tree in ts.trees(): |
2243 | | - for s in ts.samples(): |
2244 | | - if tree.interval[1] <= 1: |
2245 | | - self.assertTrue(tree.parent(s) == tskit.NULL) |
2246 | | - elif tree.interval[1] <= 3: |
2247 | | - if s in [0, 1, 2]: # Missing data at pos <=3 for these samples |
2248 | | - self.assertTrue(tree.parent(s) == tskit.NULL) |
2249 | | - else: |
2250 | | - self.assertTrue(tree.parent(s) != tskit.NULL) |
2251 | | - elif tree.interval[0] >= 9: |
2252 | | - self.assertTrue(tree.parent(s) == tskit.NULL) |
2253 | | - elif tree.interval[0] >= 7: |
2254 | | - if s in [3, 4, 5]: # Missing data at pos >=7 for these samples |
2255 | | - self.assertTrue(tree.parent(s) == tskit.NULL) |
2256 | | - else: |
2257 | | - self.assertTrue(tree.parent(s) != tskit.NULL) |
2258 | | - |
2259 | | - self.assertTrue(np.all(sites_by_samples == ts.genotype_matrix().T)) |
2260 | | - |
2261 | | - def test_large_truncated_fragments(self): |
2262 | | - """ |
2263 | | - A bit like fragments produced from a sequencer |
2264 | | - """ |
2265 | | - ts = msprime.simulate( |
2266 | | - 10, Ne=1e2, length=400, recombination_rate=1e-4, mutation_rate=2e-4, |
2267 | | - random_seed=1) |
2268 | | - truncated_ts = tsutil.truncate_ts_samples(ts, average_span=200, random_seed=123) |
2269 | | - sd = tsinfer.SampleData.from_tree_sequence(truncated_ts, use_times=False) |
2270 | | - # Cannot use the normal `simplify` as this removes parts of the TS where only |
2271 | | - # one sample is connected to the root (& the other samples have missing data) |
2272 | | - ts_inferred = tsinfer.infer(sd, engine="P") |
2273 | | - non_missing = truncated_ts.genotype_matrix() != tskit.MISSING_DATA |
2274 | | - self.assertTrue( |
2275 | | - np.all(ts_inferred.genotype_matrix()[non_missing] == |
2276 | | - truncated_ts.genotype_matrix()[non_missing])) |
| 2212 | + self.assertTrue(np.all(expected == ts.genotype_matrix())) |
0 commit comments