Skip to content

Commit 9c26fc0

Browse files
committed
Allow sample missing data
1 parent 87fdc66 commit 9c26fc0

File tree

5 files changed

+349
-70
lines changed

5 files changed

+349
-70
lines changed

tests/test_formats.py

Lines changed: 51 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -241,47 +241,59 @@ def test_provenance(self):
241241

242242
def test_variant_errors(self):
243243
input_file = formats.SampleData(sequence_length=10)
244-
genotypes = np.zeros(2, np.int8)
244+
genotypes = [0, 0]
245245
input_file.add_site(0, alleles=["0", "1"], genotypes=genotypes)
246+
self.assertRaises(
247+
ValueError, input_file.add_site, position=1,
248+
alleles=["0", "1", "2"], genotypes=genotypes)
246249
for bad_position in [-1, 10, 100]:
247250
self.assertRaises(
248251
ValueError, input_file.add_site, position=bad_position,
249252
alleles=["0", "1"], genotypes=genotypes)
250-
for bad_genotypes in [[0, 2], [-1, 0], [], [0], [0, 0, 0]]:
251-
genotypes = np.array(bad_genotypes, dtype=np.int8)
253+
for bad_genotypes in [[0, 2], [-2, 0], [], [0], [0, 0, 0]]:
252254
self.assertRaises(
253255
ValueError, input_file.add_site, position=1,
254-
alleles=["0", "1"], genotypes=genotypes)
256+
alleles=["0", "1"], genotypes=bad_genotypes)
255257
self.assertRaises(
256258
ValueError, input_file.add_site, position=1,
257-
alleles=["0", "1", "2"], genotypes=np.zeros(2, dtype=np.int8))
259+
alleles=["0"], genotypes=[0, 1])
258260
self.assertRaises(
259261
ValueError, input_file.add_site, position=1,
260-
alleles=["0"], genotypes=np.array([0, 1], dtype=np.int8))
262+
alleles=["0", "1"], genotypes=[0, 2])
261263
self.assertRaises(
262264
ValueError, input_file.add_site, position=1,
263-
alleles=["0", "1"], genotypes=np.array([0, 2], dtype=np.int8))
264-
self.assertRaises(
265-
ValueError, input_file.add_site, position=1,
266-
alleles=["0", "0"], genotypes=np.array([0, 2], dtype=np.int8))
265+
alleles=["0", "0"], genotypes=[0, 2])
267266

268267
def test_invalid_inference_sites(self):
269268
# Trying to add singletons or fixed sites as inference sites
270269
# raise an error
271270
input_file = formats.SampleData()
272271
# Make sure this is OK
273-
input_file.add_site(0, [0, 1, 1], inference=True)
272+
input_file.add_site(0, [0, 1, 1, tskit.MISSING_DATA], inference=True)
273+
self.assertRaises(
274+
ValueError, input_file.add_site,
275+
position=1, genotypes=[0, 0, 0, 0], inference=True)
274276
self.assertRaises(
275277
ValueError, input_file.add_site,
276-
position=1, genotypes=[0, 0, 0], inference=True)
278+
position=1, genotypes=[1, 0, 0, 0], inference=True)
277279
self.assertRaises(
278280
ValueError, input_file.add_site,
279-
position=1, genotypes=[1, 0, 0], inference=True)
281+
position=1, genotypes=[1, 1, 1, 1], inference=True)
280282
self.assertRaises(
281283
ValueError, input_file.add_site,
282-
position=1, genotypes=[1, 1, 1], inference=True)
284+
position=1, genotypes=[tskit.MISSING_DATA, 0, 0, 0], inference=True)
285+
self.assertRaises(
286+
ValueError, input_file.add_site,
287+
position=1, genotypes=[tskit.MISSING_DATA, 1, 1, 1], inference=True)
288+
self.assertRaises(
289+
ValueError, input_file.add_site,
290+
position=1, genotypes=[tskit.MISSING_DATA, 0, 1, 0], inference=True)
291+
self.assertRaises(
292+
ValueError, input_file.add_site,
293+
position=1, genotypes=[tskit.MISSING_DATA] * 4, inference=True)
294+
# Check we can still add at pos 1
283295
input_file.add_site(
284-
position=1, genotypes=[1, 0, 1], inference=True)
296+
position=1, genotypes=[1, 0, 1, tskit.MISSING_DATA], inference=True)
285297

286298
def test_duplicate_sites(self):
287299
# Duplicate sites are not accepted.
@@ -770,6 +782,30 @@ def test_sequence_length(self):
770782
data.finalise()
771783
self.assertEqual(data.sequence_length, 1)
772784

785+
def test_missing_data(self):
786+
u = tskit.MISSING_DATA
787+
sites_by_samples = np.array([
788+
[u, u, u, 1, 1, 0, 1, 1, 1],
789+
[u, u, u, 1, 1, 0, 1, 1, 0],
790+
[u, u, u, 1, 0, 1, 1, 0, 1],
791+
[u, 0, 0, 1, 1, 1, 1, u, u],
792+
[u, 0, 1, 1, 1, 0, 1, u, u],
793+
[u, 1, 1, 0, 0, 0, 0, u, u]
794+
], dtype=np.int8)
795+
with tsinfer.SampleData() as data:
796+
for col in range(sites_by_samples.shape[1]):
797+
data.add_site(col, sites_by_samples[:, col])
798+
799+
self.assertEqual(data.sequence_length, 9.0)
800+
self.assertEqual(data.num_sites, 9)
801+
# First site is entirely missing, second is a singleton with missing data =>
802+
# neither should be marked for inference
803+
inference_sites = data.sites_inference[:]
804+
self.assertEqual(inference_sites[0], 0) # Entirely missing data
805+
self.assertEqual(inference_sites[1], 0) # Singleton with missing data
806+
for i in inference_sites[2:]:
807+
self.assertEqual(i, 1)
808+
773809

774810
class TestAncestorData(unittest.TestCase, DataContainerMixin):
775811
"""

tests/test_inference.py

Lines changed: 114 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -615,8 +615,9 @@ def verify_inserted_ancestors(self, ts):
615615
tsinfer.build_simulated_ancestors(sample_data, ancestor_data, ts)
616616
ancestor_data.finalise()
617617

618-
A = np.zeros(
619-
(ancestor_data.num_sites, ancestor_data.num_ancestors), dtype=np.uint8)
618+
A = np.full(
619+
(ancestor_data.num_sites, ancestor_data.num_ancestors), tskit.MISSING_DATA,
620+
dtype=np.int8)
620621
start = ancestor_data.ancestors_start[:]
621622
end = ancestor_data.ancestors_end[:]
622623
ancestors = ancestor_data.ancestors_haplotype[:]
@@ -1922,3 +1923,114 @@ def verify_example(self, full_subset, samples, ancestors, path_compression):
19221923
self.assertEqual(expected_sample_ancestors, num_sample_ancestors)
19231924
tsinfer.verify(samples, final_ts.simplify())
19241925
ancestors_ts = augmented_ancestors
1926+
1927+
1928+
class TestMissingSampleDataInference(unittest.TestCase):
1929+
"""
1930+
Test that we can infer sites with tskit.MISSING_DATA, using both the PY and C engines
1931+
"""
1932+
def test_missing_haplotypes(self):
1933+
u = tskit.MISSING_DATA
1934+
sites_by_samples = np.array([
1935+
[u, u, u, u],
1936+
[u, 1, 0, u],
1937+
[u, 0, 1, u],
1938+
[u, 1, 1, u]
1939+
], dtype=np.int8)
1940+
with tsinfer.SampleData() as sample_data:
1941+
for col in range(sites_by_samples.shape[1]):
1942+
sample_data.add_site(col, sites_by_samples[:, col])
1943+
ts = tsinfer.infer(sample_data)
1944+
self.assertTrue(np.all(sites_by_samples == ts.genotype_matrix().T))
1945+
1946+
def test_small_truncated_fragments(self):
1947+
u = tskit.MISSING_DATA
1948+
sites_by_samples = np.array([
1949+
[u, u, u, 1, 1, 0, 1, 1, 1, u],
1950+
[u, u, u, 1, 0, 0, 1, 1, 0, u],
1951+
[u, u, u, 1, 0, 1, 1, 0, 1, u],
1952+
[u, 0, 0, 1, 0, 1, 1, u, u, u],
1953+
[u, 0, 1, 1, 0, 0, 1, u, u, u],
1954+
[u, 1, 1, 0, 0, 0, 0, u, u, u]
1955+
], dtype=np.int8)
1956+
with tsinfer.SampleData() as sample_data:
1957+
for col in range(sites_by_samples.shape[1]):
1958+
sample_data.add_site(col, sites_by_samples[:, col])
1959+
for e in [tsinfer.PY_ENGINE, tsinfer.C_ENGINE]:
1960+
ancestors = tsinfer.generate_ancestors(sample_data, engine=e)
1961+
ancestors_ts = tsinfer.match_ancestors(
1962+
sample_data, ancestors, engine=e, extended_checks=True)
1963+
ts = tsinfer.match_samples(
1964+
sample_data, ancestors_ts, engine=e, extended_checks=True)
1965+
self.assertTrue(1.0 in list(ts.breakpoints())) # End of lft unknown region
1966+
self.assertTrue(3.0 in list(ts.breakpoints())) # End of 1st unknown batch
1967+
self.assertTrue(7.0 in list(ts.breakpoints())) # Start of 2nd unknown batch
1968+
self.assertTrue(9.0 in list(ts.breakpoints())) # Start of rgt unknown region
1969+
for tree in ts.trees():
1970+
for s in ts.samples():
1971+
if tree.interval[1] <= 1:
1972+
self.assertTrue(tree.parent(s) == tskit.NULL)
1973+
elif tree.interval[1] <= 3:
1974+
if s in [0, 1, 2]:
1975+
self.assertTrue(tree.parent(s) == tskit.NULL)
1976+
else:
1977+
self.assertTrue(tree.parent(s) != tskit.NULL)
1978+
elif tree.interval[0] >= 9:
1979+
self.assertTrue(tree.parent(s) == tskit.NULL)
1980+
elif tree.interval[0] >= 7:
1981+
if s in [3, 4, 5]:
1982+
self.assertTrue(tree.parent(s) == tskit.NULL)
1983+
else:
1984+
self.assertTrue(tree.parent(s) != tskit.NULL)
1985+
1986+
self.assertTrue(np.all(sites_by_samples == ts.genotype_matrix().T))
1987+
1988+
def test_large_truncated_fragments(self):
1989+
def truncate_ts_samples(ts, average_span, random_seed, min_span=5):
1990+
"""
1991+
Create a tree sequence that has sample nodes which have been truncated
1992+
so that they span only a small region of the genome. The length of the
1993+
truncated spans is given by a poisson distribution whose mean is average_span
1994+
but which cannot go below a fixed min_span, or above the sequence_length
1995+
1996+
Samples are truncated by removing the edges that connect them to the rest
1997+
of the tree.
1998+
"""
1999+
np.random.seed(random_seed)
2000+
# Make a dict of (left,right) tuples giving the new limits of each sample
2001+
# Keyed by sample ID.
2002+
keep = {}
2003+
# for simplicity, we pick lengths from a poisson distribution with mean average_span
2004+
for sample_id, span in zip(
2005+
ts.samples(), np.random.poisson(average_span, ts.num_samples)):
2006+
span = max(span, min_span)
2007+
span = min(span, ts.sequence_length)
2008+
start = np.random.uniform(0, ts.sequence_length-span)
2009+
keep[sample_id] = (start, start+span)
2010+
2011+
tables = ts.dump_tables()
2012+
tables.edges.clear()
2013+
for e in ts.tables.edges:
2014+
if e.child not in keep:
2015+
left, right = e.left, e.right
2016+
else:
2017+
if e.right <= keep[e.child][0] or e.left >= keep[e.child][1]:
2018+
continue # this edge is outside the focal region
2019+
else:
2020+
left = max(e.left, keep[e.child][0])
2021+
right = min(e.right, keep[e.child][1])
2022+
tables.edges.add_row(left, right, e.parent, e.child)
2023+
return tables.tree_sequence()
2024+
2025+
ts = msprime.simulate(
2026+
100, Ne=1e2, length=400, recombination_rate=1e-4, mutation_rate=2e-4,
2027+
random_seed=1)
2028+
truncated_ts = truncate_ts_samples(ts, average_span=200, random_seed=123)
2029+
sd = tsinfer.SampleData.from_tree_sequence(truncated_ts, use_times=False)
2030+
# Cannot use the normal `simplify` as this removes parts of the TS where only
2031+
# one sample is connected to the root (& the other samples have missing data)
2032+
ts_inferred = tsinfer.infer(sd, simplify=False)
2033+
# Instead we run simplify explicitly, with `keep_unary=True`
2034+
ts_inferred = ts_inferred.simplify(filter_sites=False, keep_unary=True)
2035+
self.assertTrue(
2036+
np.all(ts_inferred.genotype_matrix() == truncated_ts.genotype_matrix()))

tsinfer/algorithm.py

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -140,8 +140,14 @@ def ancestor_descriptors(self):
140140

141141
def compute_ancestral_states(self, a, focal_site, sites):
142142
"""
143-
Together with make_ancestor, this is the main algorithm as implemented in Fig S2
144-
of the preprint, with the buffer.
143+
For a given focal site, and set of sites to fill in (usually all the ones
144+
leftwards or rightwards), augment the haplotype array a with the inferred sites.
145+
Together with `make_ancestor`, which calls this function, these describe the main
146+
algorithm as implemented in Fig S2 of the preprint, with the buffer.
147+
148+
TODO - account for tskit.MISSING_DATA in samples (e.g. when encountered in
149+
the remove_buffer we should keep the sample in the buffer until we know that
150+
there is a conflict, rather than clear the remove buffer on every iteration)
145151
"""
146152
focal_age = self.sites[focal_site].age
147153
S = set(np.where(self.sites[focal_site].genotypes == 1)[0])
@@ -154,8 +160,8 @@ def compute_ancestral_states(self, a, focal_site, sites):
154160
last_site = l
155161
if self.sites[l].age > focal_age:
156162
g_l = self.sites[l].genotypes
157-
ones = sum(g_l[u] for u in S)
158-
zeros = len(S) - ones
163+
ones = sum(g_l[u] == 1 for u in S)
164+
zeros = sum(g_l[u] == 0 for u in S)
159165
# print("\tsite", l, ones, zeros, sep="\t")
160166
consensus = 0
161167
if ones >= zeros:
@@ -174,6 +180,7 @@ def compute_ancestral_states(self, a, focal_site, sites):
174180
if g_l[u] != consensus:
175181
remove_buffer.append(u)
176182
a[l] = consensus
183+
assert a[last_site] != tskit.MISSING_DATA
177184
return last_site
178185

179186
def make_ancestor(self, focal_sites, a):
@@ -189,27 +196,34 @@ def make_ancestor(self, focal_sites, a):
189196
for focal_site in focal_sites:
190197
a[focal_site] = 1
191198
S = set(np.where(self.sites[focal_sites[0]].genotypes == 1)[0])
199+
# Interpolate ancestral haplotype within focal region (i.e. region
200+
# spanning from leftmost to rightmost focal site)
192201
for j in range(len(focal_sites) - 1):
202+
# Interpolate region between focal site j and focal site j+1
193203
for l in range(focal_sites[j] + 1, focal_sites[j + 1]):
194204
a[l] = 0
195205
if self.sites[l].age > focal_age:
196206
g_l = self.sites[l].genotypes
197-
ones = sum(g_l[u] for u in S)
198-
zeros = len(S) - ones
207+
ones = sum(g_l[u] == 1 for u in S)
208+
zeros = sum(g_l[u] == 0 for u in S)
199209
# print("\t", l, ones, zeros, sep="\t")
200-
if ones >= zeros:
210+
if ones >= zeros: # Should probably be "ones > zeros" (see below)
211+
# Since this site should be older, this is a conflict
212+
# We just take the majority rule. If equal, we assume that
213+
# the derived variant is more likely (this is probably wrong)
214+
# (we could possibly do something more sophisticated for ancient
215+
# samples by taking into account the sample age)
201216
a[l] = 1
202-
# Go rightwards
217+
# Extend ancestral haplotype rightwards from rightmost focal site
203218
focal_site = focal_sites[-1]
204219
last_site = self.compute_ancestral_states(
205220
a, focal_site, range(focal_site + 1, self.num_sites))
206221
assert a[last_site] != tskit.MISSING_DATA
207222
end = last_site + 1
208-
# Go leftwards
223+
# Extend ancestral haplotype leftwards from leftmost focal site
209224
focal_site = focal_sites[0]
210225
last_site = self.compute_ancestral_states(
211226
a, focal_site, range(focal_site - 1, -1, -1))
212-
assert a[last_site] != tskit.MISSING_DATA
213227
start = last_site
214228
return start, end
215229

@@ -386,7 +400,7 @@ def update_node_time(self, child_id, pc_parent_id):
386400
edge = edge.next
387401
assert min_parent_time >= 0
388402
assert min_parent_time <= self.time[0]
389-
# For the asserttion to be violated we would need to have 64K pc
403+
# For the assertion to be violated we would need to have 64K pc
390404
# ancestors sequentially copying from each other.
391405
self.time[pc_parent_id] = min_parent_time - (1 / 2**16)
392406
assert self.time[pc_parent_id] > self.time[child_id]
@@ -553,6 +567,10 @@ def dump_nodes(self):
553567
return flags, time
554568

555569
def dump_edges(self):
570+
"""
571+
Return all the edges, in path order (such that all edges for a child are gathered
572+
together, and the edges for this child are always listed from left to right)
573+
"""
556574
left = np.zeros(self.num_edges, dtype=np.int32)
557575
right = np.zeros(self.num_edges, dtype=np.int32)
558576
parent = np.zeros(self.num_edges, dtype=np.int32)

tsinfer/formats.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1121,12 +1121,17 @@ def add_site(
11211121

11221122
if alleles is None:
11231123
alleles = ["0", "1"]
1124-
if len(alleles) > 2:
1125-
raise ValueError("Only biallelic sites supported")
1124+
if len(set(alleles) - set([None])) > 2:
1125+
raise ValueError("Only biallelic sites supported: {}".format(alleles))
11261126
if len(set(alleles)) != len(alleles):
11271127
raise ValueError("Alleles must be distinct")
1128-
if np.any(genotypes >= len(alleles)) or np.any(genotypes < 0):
1129-
raise ValueError("Genotypes values must be between 0 and len(alleles) - 1")
1128+
# Check we can never confuse a real allele with the value for MISSING_DATA
1129+
assert not (0 <= tskit.MISSING_DATA <= len(alleles))
1130+
if np.any(np.logical_and(genotypes < 0, genotypes != tskit.MISSING_DATA)):
1131+
raise ValueError("Non-missing values for genotypes cannot be negative")
1132+
if np.any(np.logical_and(
1133+
genotypes >= len(alleles), genotypes != tskit.MISSING_DATA)):
1134+
raise ValueError("Non-missing values for genotypes must be < len(alleles)")
11301135
if genotypes.shape != (self.num_samples,):
11311136
raise ValueError("Must have num_samples genotypes.")
11321137
if position < 0:
@@ -1136,8 +1141,12 @@ def add_site(
11361141
if position <= self._last_position:
11371142
raise ValueError(
11381143
"Sites positions must be unique and added in increasing order")
1139-
count = np.sum(genotypes)
1140-
if count > 1 and count < self.num_samples:
1144+
1145+
n_known = np.sum(genotypes != tskit.MISSING_DATA)
1146+
n_unknown = self.num_samples - n_known
1147+
n_ancestral = np.sum(genotypes == 0)
1148+
n_derived = n_known - n_ancestral
1149+
if n_derived > 1 and n_derived < n_known:
11411150
if inference is None:
11421151
inference = True
11431152
else:
@@ -1147,7 +1156,8 @@ def add_site(
11471156
raise ValueError(
11481157
"Cannot specify singletons or fixed sites for inference")
11491158
if age is None:
1150-
age = count
1159+
age = n_derived
1160+
age += n_unknown/2.0 # Slight hack: unknown alleles create intermediate age
11511161
site_id = self._sites_writer.add(
11521162
position=position, genotypes=genotypes,
11531163
metadata=self._check_metadata(metadata),

0 commit comments

Comments
 (0)