Skip to content

Commit 7835114

Browse files
committed
Allow sample missing data
1 parent d2300a8 commit 7835114

File tree

6 files changed

+391
-91
lines changed

6 files changed

+391
-91
lines changed

tests/test_formats.py

Lines changed: 51 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -241,47 +241,59 @@ def test_provenance(self):
241241

242242
def test_variant_errors(self):
243243
input_file = formats.SampleData(sequence_length=10)
244-
genotypes = np.zeros(2, np.int8)
244+
genotypes = [0, 0]
245245
input_file.add_site(0, alleles=["0", "1"], genotypes=genotypes)
246+
self.assertRaises(
247+
ValueError, input_file.add_site, position=1,
248+
alleles=["0", "1", "2"], genotypes=genotypes)
246249
for bad_position in [-1, 10, 100]:
247250
self.assertRaises(
248251
ValueError, input_file.add_site, position=bad_position,
249252
alleles=["0", "1"], genotypes=genotypes)
250-
for bad_genotypes in [[0, 2], [-1, 0], [], [0], [0, 0, 0]]:
251-
genotypes = np.array(bad_genotypes, dtype=np.int8)
253+
for bad_genotypes in [[0, 2], [-2, 0], [], [0], [0, 0, 0]]:
252254
self.assertRaises(
253255
ValueError, input_file.add_site, position=1,
254-
alleles=["0", "1"], genotypes=genotypes)
256+
alleles=["0", "1"], genotypes=bad_genotypes)
255257
self.assertRaises(
256258
ValueError, input_file.add_site, position=1,
257-
alleles=["0", "1", "2"], genotypes=np.zeros(2, dtype=np.int8))
259+
alleles=["0"], genotypes=[0, 1])
258260
self.assertRaises(
259261
ValueError, input_file.add_site, position=1,
260-
alleles=["0"], genotypes=np.array([0, 1], dtype=np.int8))
262+
alleles=["0", "1"], genotypes=[0, 2])
261263
self.assertRaises(
262264
ValueError, input_file.add_site, position=1,
263-
alleles=["0", "1"], genotypes=np.array([0, 2], dtype=np.int8))
264-
self.assertRaises(
265-
ValueError, input_file.add_site, position=1,
266-
alleles=["0", "0"], genotypes=np.array([0, 2], dtype=np.int8))
265+
alleles=["0", "0"], genotypes=[0, 2])
267266

268267
def test_invalid_inference_sites(self):
269268
# Trying to add singletons or fixed sites as inference sites
270269
# raises an error
271270
input_file = formats.SampleData()
272271
# Make sure this is OK
273-
input_file.add_site(0, [0, 1, 1], inference=True)
272+
input_file.add_site(0, [0, 1, 1, tskit.MISSING_DATA], inference=True)
273+
self.assertRaises(
274+
ValueError, input_file.add_site,
275+
position=1, genotypes=[0, 0, 0, 0], inference=True)
274276
self.assertRaises(
275277
ValueError, input_file.add_site,
276-
position=1, genotypes=[0, 0, 0], inference=True)
278+
position=1, genotypes=[1, 0, 0, 0], inference=True)
277279
self.assertRaises(
278280
ValueError, input_file.add_site,
279-
position=1, genotypes=[1, 0, 0], inference=True)
281+
position=1, genotypes=[1, 1, 1, 1], inference=True)
280282
self.assertRaises(
281283
ValueError, input_file.add_site,
282-
position=1, genotypes=[1, 1, 1], inference=True)
284+
position=1, genotypes=[tskit.MISSING_DATA, 0, 0, 0], inference=True)
285+
self.assertRaises(
286+
ValueError, input_file.add_site,
287+
position=1, genotypes=[tskit.MISSING_DATA, 1, 1, 1], inference=True)
288+
self.assertRaises(
289+
ValueError, input_file.add_site,
290+
position=1, genotypes=[tskit.MISSING_DATA, 0, 1, 0], inference=True)
291+
self.assertRaises(
292+
ValueError, input_file.add_site,
293+
position=1, genotypes=[tskit.MISSING_DATA] * 4, inference=True)
294+
# Check we can still add at pos 1
283295
input_file.add_site(
284-
position=1, genotypes=[1, 0, 1], inference=True)
296+
position=1, genotypes=[1, 0, 1, tskit.MISSING_DATA], inference=True)
285297

286298
def test_duplicate_sites(self):
287299
# Duplicate sites are not accepted.
@@ -770,6 +782,30 @@ def test_sequence_length(self):
770782
data.finalise()
771783
self.assertEqual(data.sequence_length, 1)
772784

785+
def test_missing_data(self):
786+
u = tskit.MISSING_DATA
787+
sites_by_samples = np.array([
788+
[u, u, u, 1, 1, 0, 1, 1, 1],
789+
[u, u, u, 1, 1, 0, 1, 1, 0],
790+
[u, u, u, 1, 0, 1, 1, 0, 1],
791+
[u, 0, 0, 1, 1, 1, 1, u, u],
792+
[u, 0, 1, 1, 1, 0, 1, u, u],
793+
[u, 1, 1, 0, 0, 0, 0, u, u]
794+
], dtype=np.int8)
795+
with tsinfer.SampleData() as data:
796+
for col in range(sites_by_samples.shape[1]):
797+
data.add_site(col, sites_by_samples[:, col])
798+
799+
self.assertEqual(data.sequence_length, 9.0)
800+
self.assertEqual(data.num_sites, 9)
801+
# First site is entirely missing, second is singleton with missing data =>
802+
# neither should be marked for inference
803+
inference_sites = data.sites_inference[:]
804+
self.assertEqual(inference_sites[0], 0) # Entirely missing data
805+
self.assertEqual(inference_sites[1], 0) # Singleton with missing data
806+
for i in inference_sites[2:]:
807+
self.assertEqual(i, 1)
808+
773809

774810
class TestAncestorData(unittest.TestCase, DataContainerMixin):
775811
"""

tests/test_inference.py

Lines changed: 132 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -615,8 +615,8 @@ def verify_inserted_ancestors(self, ts):
615615
ancestor_data.finalise()
616616

617617
A = np.full(
618-
(ancestor_data.num_sites, ancestor_data.num_ancestors),
619-
tskit.MISSING_DATA, dtype=np.int8)
618+
(ancestor_data.num_sites, ancestor_data.num_ancestors), tskit.MISSING_DATA,
619+
dtype=np.int8)
620620
start = ancestor_data.ancestors_start[:]
621621
end = ancestor_data.ancestors_end[:]
622622
ancestors = ancestor_data.ancestors_haplotype[:]
@@ -1649,10 +1649,12 @@ class TestExtractAncestors(unittest.TestCase):
16491649
"""
16501650
def verify(self, samples):
16511651
ancestors = tsinfer.generate_ancestors(samples)
1652+
# this ancestors TS has positions mapped only to inference sites
16521653
ancestors_ts_1 = tsinfer.match_ancestors(samples, ancestors)
16531654
ts = tsinfer.match_samples(
16541655
samples, ancestors_ts_1, path_compression=False, simplify=False)
16551656
t1 = ancestors_ts_1.dump_tables()
1657+
16561658
t2, node_id_map = tsinfer.extract_ancestors(samples, ts)
16571659
self.assertEqual(len(t2.provenances), len(t1.provenances) + 2)
16581660
t1.provenances.clear()
@@ -1662,14 +1664,6 @@ def verify(self, samples):
16621664
# for now.
16631665
t2.populations.clear()
16641666

1665-
self.assertEqual(t1.nodes, t2.nodes)
1666-
self.assertEqual(t1.edges, t2.edges)
1667-
self.assertEqual(t1.sites, t2.sites)
1668-
self.assertEqual(t1.mutations, t2.mutations)
1669-
self.assertEqual(t1.populations, t2.populations)
1670-
self.assertEqual(t1.individuals, t2.individuals)
1671-
self.assertEqual(t1.sites, t2.sites)
1672-
16731667
self.assertEqual(t1, t2)
16741668

16751669
for node in ts.nodes():
@@ -1683,8 +1677,7 @@ def test_simple_simulation(self):
16831677
def test_non_zero_one_mutations(self):
16841678
ts = msprime.simulate(10, recombination_rate=5, random_seed=2)
16851679
ts = msprime.mutate(
1686-
ts, rate=5, model=msprime.InfiniteSites(msprime.NUCLEOTIDES),
1687-
random_seed=15)
1680+
ts, rate=2, model=msprime.InfiniteSites(msprime.NUCLEOTIDES), random_seed=15)
16881681
self.assertGreater(ts.num_mutations, 0)
16891682
self.verify(tsinfer.SampleData.from_tree_sequence(ts, use_times=False))
16901683

@@ -1931,3 +1924,130 @@ def verify_example(self, full_subset, samples, ancestors, path_compression):
19311924
self.assertEqual(expected_sample_ancestors, num_sample_ancestors)
19321925
tsinfer.verify(samples, final_ts.simplify())
19331926
ancestors_ts = augmented_ancestors
1927+
1928+
1929+
class TestMissingSampleDataInference(unittest.TestCase):
1930+
"""
1931+
Test that we can infer sites with tskit.MISSING_DATA, using both the PY and C engines
1932+
"""
1933+
def test_missing_haplotypes(self):
1934+
u = tskit.MISSING_DATA
1935+
sites_by_samples = np.array([
1936+
[u, u, u, u],
1937+
[u, 1, 0, u],
1938+
[u, 0, 1, u],
1939+
[u, 1, 1, u]
1940+
], dtype=np.int8)
1941+
with tsinfer.SampleData() as sample_data:
1942+
for col in range(sites_by_samples.shape[1]):
1943+
sample_data.add_site(col, sites_by_samples[:, col])
1944+
ts = tsinfer.infer(sample_data)
1945+
self.assertTrue(np.all(sites_by_samples == ts.genotype_matrix().T))
1946+
1947+
def test_samples_missing_inference_sites(self):
1948+
u = tskit.MISSING_DATA
1949+
sites_by_samples = np.array([
1950+
[1, 0, 0, u],
1951+
[1, 0, 0, u],
1952+
[0, 1, 1, 1],
1953+
[u, u, u, 1]], dtype=np.int8)
1954+
with tsinfer.SampleData() as sample_data:
1955+
for col in range(sites_by_samples.shape[1]):
1956+
sample_data.add_site(col, sites_by_samples[:, col])
1957+
ts = tsinfer.infer(sample_data, simplify=False)
1958+
self.assertTrue(np.all(sites_by_samples == ts.genotype_matrix().T))
1959+
1960+
def test_small_truncated_fragments(self):
1961+
u = tskit.MISSING_DATA
1962+
sites_by_samples = np.array([
1963+
[u, u, u, 1, 1, 0, 1, 1, 1, u],
1964+
[u, u, u, 1, 0, 0, 1, 1, 0, u],
1965+
[u, u, u, 1, 0, 1, 1, 0, 1, u],
1966+
[u, 0, 0, 1, 0, 1, 1, u, u, u],
1967+
[u, 0, 1, 1, 0, 0, 1, u, u, u],
1968+
[u, 1, 1, 0, 0, 0, 0, u, u, u]
1969+
], dtype=np.int8)
1970+
with tsinfer.SampleData() as sample_data:
1971+
for col in range(sites_by_samples.shape[1]):
1972+
sample_data.add_site(col, sites_by_samples[:, col])
1973+
for e in [tsinfer.PY_ENGINE, tsinfer.C_ENGINE]:
1974+
ancestors = tsinfer.generate_ancestors(sample_data, engine=e)
1975+
ancestors_ts = tsinfer.match_ancestors(
1976+
sample_data, ancestors, engine=e, extended_checks=True)
1977+
ts = tsinfer.match_samples(
1978+
sample_data, ancestors_ts, engine=e, extended_checks=True)
1979+
self.assertTrue(1.0 in ts.breakpoints(True)) # End of lft unknown region
1980+
self.assertTrue(3.0 in ts.breakpoints(True)) # End of 1st unknown batch
1981+
self.assertTrue(7.0 in ts.breakpoints(True)) # Start of 2nd unknown batch
1982+
self.assertTrue(9.0 in ts.breakpoints(True)) # Start of rgt unknown region
1983+
for tree in ts.trees():
1984+
for s in ts.samples():
1985+
if tree.interval[1] <= 1:
1986+
self.assertTrue(tree.parent(s) == tskit.NULL)
1987+
elif tree.interval[1] <= 3:
1988+
if s in [0, 1, 2]:
1989+
self.assertTrue(tree.parent(s) == tskit.NULL)
1990+
else:
1991+
self.assertTrue(tree.parent(s) != tskit.NULL)
1992+
elif tree.interval[0] >= 9:
1993+
self.assertTrue(tree.parent(s) == tskit.NULL)
1994+
elif tree.interval[0] >= 7:
1995+
if s in [3, 4, 5]:
1996+
self.assertTrue(tree.parent(s) == tskit.NULL)
1997+
else:
1998+
self.assertTrue(tree.parent(s) != tskit.NULL)
1999+
2000+
self.assertTrue(np.all(sites_by_samples == ts.genotype_matrix().T))
2001+
2002+
def test_large_truncated_fragments(self):
2003+
"""
2004+
A bit like fragments produced from a sequencer
2005+
"""
2006+
def truncate_ts_samples(ts, average_span, random_seed, min_span=5):
2007+
"""
2008+
Create a tree sequence that has sample nodes which have been truncated
2009+
so that they span only a small region of the genome. The length of the
2010+
truncated spans is given by a poisson distribution whose mean is average_span
2011+
but which cannot go below a fixed min_span, or above the sequence_length
2012+
2013+
Samples are truncated by removing the edges that connect them to the rest
2014+
of the tree.
2015+
"""
2016+
np.random.seed(random_seed)
2017+
# Make a list of (left,right) tuples giving the new limits of each sample
2018+
# Keyed by sample ID.
2019+
keep = {}
2020+
# for simplicity, pick lengths from a Poisson distribution (mean average_span)
2021+
for sample_id, span in zip(
2022+
ts.samples(), np.random.poisson(average_span, ts.num_samples)):
2023+
span = max(span, min_span)
2024+
span = min(span, ts.sequence_length)
2025+
start = np.random.uniform(0, ts.sequence_length-span)
2026+
keep[sample_id] = (start, start+span)
2027+
2028+
tables = ts.dump_tables()
2029+
tables.edges.clear()
2030+
for e in ts.tables.edges:
2031+
if e.child not in keep:
2032+
left, right = e.left, e.right
2033+
else:
2034+
if e.right <= keep[e.child][0] or e.left >= keep[e.child][1]:
2035+
continue # this edge is outside the focal region
2036+
else:
2037+
left = max(e.left, keep[e.child][0])
2038+
right = min(e.right, keep[e.child][1])
2039+
tables.edges.add_row(left, right, e.parent, e.child)
2040+
return tables.tree_sequence()
2041+
2042+
ts = msprime.simulate(
2043+
100, Ne=1e2, length=400, recombination_rate=1e-4, mutation_rate=2e-4,
2044+
random_seed=1)
2045+
truncated_ts = truncate_ts_samples(ts, average_span=200, random_seed=123)
2046+
sd = tsinfer.SampleData.from_tree_sequence(truncated_ts, use_times=False)
2047+
# Cannot use the normal `simplify` as this removes parts of the TS where only
2048+
# one sample is connected to the root (& the other samples have missing data)
2049+
ts_inferred = tsinfer.infer(sd, simplify=False)
2050+
# Instead we run simplify explicitly, with `keep_unary=True`
2051+
ts_inferred = ts_inferred.simplify(filter_sites=False, keep_unary=True)
2052+
self.assertTrue(
2053+
np.all(ts_inferred.genotype_matrix() == truncated_ts.genotype_matrix()))

tsinfer/algorithm.py

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -140,8 +140,14 @@ def ancestor_descriptors(self):
140140

141141
def compute_ancestral_states(self, a, focal_site, sites):
142142
"""
143-
Together with make_ancestor, this is the main algorithm as implemented in Fig S2
144-
of the preprint, with the buffer.
143+
For a given focal site, and set of sites to fill in (usually all the ones
144+
leftwards or rightwards), augment the haplotype array `a` with the inferred sites.
145+
Together with `make_ancestor`, which calls this function, these describe the main
146+
algorithm as implemented in Fig S2 of the preprint, with the buffer.
147+
148+
TODO - account for tskit.MISSING_DATA in samples (e.g. when encountered in
149+
the remove_buffer we should keep the sample in the buffer until we know that
150+
there is a conflict, rather than clear the remove buffer on every iteration)
145151
"""
146152
focal_age = self.sites[focal_site].age
147153
S = set(np.where(self.sites[focal_site].genotypes == 1)[0])
@@ -154,8 +160,8 @@ def compute_ancestral_states(self, a, focal_site, sites):
154160
last_site = l
155161
if self.sites[l].age > focal_age:
156162
g_l = self.sites[l].genotypes
157-
ones = sum(g_l[u] for u in S)
158-
zeros = len(S) - ones
163+
ones = sum(g_l[u] == 1 for u in S)
164+
zeros = sum(g_l[u] == 0 for u in S)
159165
# print("\tsite", l, ones, zeros, sep="\t")
160166
consensus = 0
161167
if ones >= zeros:
@@ -174,6 +180,7 @@ def compute_ancestral_states(self, a, focal_site, sites):
174180
if g_l[u] != consensus:
175181
remove_buffer.append(u)
176182
a[l] = consensus
183+
assert a[last_site] != tskit.MISSING_DATA
177184
return last_site
178185

179186
def make_ancestor(self, focal_sites, a):
@@ -189,27 +196,34 @@ def make_ancestor(self, focal_sites, a):
189196
for focal_site in focal_sites:
190197
a[focal_site] = 1
191198
S = set(np.where(self.sites[focal_sites[0]].genotypes == 1)[0])
199+
# Interpolate ancestral haplotype within focal region (i.e. region
200+
# spanning from leftmost to rightmost focal site)
192201
for j in range(len(focal_sites) - 1):
202+
# Interpolate region between focal site j and focal site j+1
193203
for l in range(focal_sites[j] + 1, focal_sites[j + 1]):
194204
a[l] = 0
195205
if self.sites[l].age > focal_age:
196206
g_l = self.sites[l].genotypes
197-
ones = sum(g_l[u] for u in S)
198-
zeros = len(S) - ones
207+
ones = sum(g_l[u] == 1 for u in S)
208+
zeros = sum(g_l[u] == 0 for u in S)
199209
# print("\t", l, ones, zeros, sep="\t")
200-
if ones >= zeros:
210+
if ones >= zeros: # Should probably be "ones > zeros" (see below)
211+
# Since this site should be older, this is a conflict
212+
# We just take the majority rule. If equal, we assume that
213+
# the derived variant is more likely (this is probably wrong)
214+
# (we could possibly do something more sophisticated for ancient
215+
# samples by taking into account the sample age)
201216
a[l] = 1
202-
# Go rightwards
217+
# Extend ancestral haplotype rightwards from rightmost focal site
203218
focal_site = focal_sites[-1]
204219
last_site = self.compute_ancestral_states(
205220
a, focal_site, range(focal_site + 1, self.num_sites))
206221
assert a[last_site] != tskit.MISSING_DATA
207222
end = last_site + 1
208-
# Go leftwards
223+
# Extend ancestral haplotype leftwards from leftmost focal site
209224
focal_site = focal_sites[0]
210225
last_site = self.compute_ancestral_states(
211226
a, focal_site, range(focal_site - 1, -1, -1))
212-
assert a[last_site] != tskit.MISSING_DATA
213227
start = last_site
214228
return start, end
215229

@@ -386,7 +400,7 @@ def update_node_time(self, child_id, pc_parent_id):
386400
edge = edge.next
387401
assert min_parent_time >= 0
388402
assert min_parent_time <= self.time[0]
389-
# For the asserttion to be violated we would need to have 64K pc
403+
# For the assertion to be violated we would need to have 64K pc
390404
# ancestors sequentially copying from each other.
391405
self.time[pc_parent_id] = min_parent_time - (1 / 2**16)
392406
assert self.time[pc_parent_id] > self.time[child_id]
@@ -553,6 +567,10 @@ def dump_nodes(self):
553567
return flags, time
554568

555569
def dump_edges(self):
570+
"""
571+
Return all the edges, in path order (such that all edges for a child are gathered
572+
together, and the edges for this child are always listed from left to right)
573+
"""
556574
left = np.zeros(self.num_edges, dtype=np.int32)
557575
right = np.zeros(self.num_edges, dtype=np.int32)
558576
parent = np.zeros(self.num_edges, dtype=np.int32)

0 commit comments

Comments
 (0)