Skip to content

Commit c493191

Browse files
committed
All tests pass!
1 parent dd7262b commit c493191

File tree

3 files changed

+15
-24
lines changed

3 files changed

+15
-24
lines changed

tests/test_inference.py

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1656,9 +1656,6 @@ def verify(self, samples):
16561656
# for now.
16571657
t2.populations.clear()
16581658

1659-
for e1, e2 in zip(t1.edges, t2.edges):
1660-
if e1 != e2:
1661-
print(e1, e2)
16621659
self.assertEqual(t1, t2)
16631660

16641661
for node in ts.nodes():
@@ -1670,18 +1667,12 @@ def test_simple_simulation(self):
16701667
self.verify(tsinfer.SampleData.from_tree_sequence(ts))
16711668

16721669
def test_non_zero_one_mutations(self):
1673-
ts = msprime.simulate(100, length=100, recombination_rate=5, random_seed=2)
1670+
ts = msprime.simulate(10, recombination_rate=5, random_seed=2)
16741671
ts = msprime.mutate(
16751672
ts, rate=2, model=msprime.InfiniteSites(msprime.NUCLEOTIDES), random_seed=15)
16761673
self.assertGreater(ts.num_mutations, 0)
16771674
self.verify(tsinfer.SampleData.from_tree_sequence(ts, use_times=False))
16781675

1679-
def test_limited_length_ancestors(self):
1680-
ts = msprime.simulate(10, length=10, recombination_rate=5, random_seed=5)
1681-
ts = msprime.mutate(ts, rate=2, random_seed=15)
1682-
self.assertGreater(ts.num_mutations, 0)
1683-
self.verify(tsinfer.SampleData.from_tree_sequence(ts, use_times=False))
1684-
16851676
def test_random_data_small_examples(self):
16861677
np.random.seed(4)
16871678
num_random_tests = 10
@@ -1977,10 +1968,10 @@ def test_small_truncated_fragments(self):
19771968
sample_data, ancestors, engine=e, extended_checks=True)
19781969
ts = tsinfer.match_samples(
19791970
sample_data, ancestors_ts, engine=e, extended_checks=True)
1980-
self.assertTrue(1.0 in list(ts.breakpoints())) # End of lft unknown region
1981-
self.assertTrue(3.0 in list(ts.breakpoints())) # End of 1st unknown batch
1982-
self.assertTrue(7.0 in list(ts.breakpoints())) # Start of 2nd unknown batch
1983-
self.assertTrue(9.0 in list(ts.breakpoints())) # Start of rgt unknown region
1971+
self.assertTrue(1.0 in ts.breakpoints(True)) # End of lft unknown region
1972+
self.assertTrue(3.0 in ts.breakpoints(True)) # End of 1st unknown batch
1973+
self.assertTrue(7.0 in ts.breakpoints(True)) # Start of 2nd unknown batch
1974+
self.assertTrue(9.0 in ts.breakpoints(True)) # Start of rgt unknown region
19841975
for tree in ts.trees():
19851976
for s in ts.samples():
19861977
if tree.interval[1] <= 1:

tsinfer/eval_util.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -577,12 +577,14 @@ def extract_ancestors(samples, ts):
577577
# Any edge whose left-hand edge does not map to an inference site is one that has
578578
# been extended leftwards to cover potentially missing sites. To turn it back into
579579
# an ancestors TS, for these sites we go rightwards until the next inference site
580-
# edges = tables.edges
581-
# tables.edges.set_columns(
582-
# left=position[np.searchsorted(position, edges.left)],
583-
# right=edges.right,
584-
# parent=edges.parent,
585-
# child=edges.child)
580+
pos = position.copy()
581+
pos[0] = 0.0
582+
edges = tables.edges
583+
tables.edges.set_columns(
584+
left=pos[np.searchsorted(position, edges.left)],
585+
right=edges.right,
586+
parent=edges.parent,
587+
child=edges.child)
586588

587589
# We cannot have flags that are both samples and have other flags set,
588590
# so we need to unset all the sample flags for these.

tsinfer/inference.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -712,8 +712,6 @@ def insert_sites(self, tables):
712712
derived_state=variant.alleles[derived_state[inferred_site]])
713713
inferred_site += 1
714714
else:
715-
self.locate_mutations_on_tree(tree, variant, tables.mutations)
716-
"""
717715
if np.all(variant.genotypes == tskit.MISSING_DATA):
718716
# Work around the fact that map_mutations has to have at least 1
719717
ancestral_state = tskit.MISSING_DATA
@@ -739,7 +737,6 @@ def insert_sites(self, tables):
739737
tables.mutations.add_row(
740738
site=site.id, node=mutation.node,
741739
derived_state=mutation.derived_state)
742-
"""
743740
progress_monitor.update()
744741
else:
745742
# Simple case where all sites are inference sites. We save a lot of time here
@@ -795,7 +792,8 @@ def get_augmented_ancestors_tree_sequence(self, sample_indexes):
795792

796793
position = tables.sites.position
797794
pos_map = np.hstack([position, [tables.sequence_length]])
798-
# pos_map[0] = 0
795+
pos_map[0] = 0
796+
# TODO - check this works with asumented ancestors with missing data
799797
left, right, parent, child = tsb.dump_edges()
800798
tables.edges.set_columns(
801799
left=pos_map[left], right=pos_map[right], parent=parent, child=child)

0 commit comments

Comments
 (0)