Skip to content

Commit 1a95e59

Browse files
committed
Delete topology in flanking regions
Fixes #483
1 parent e0d1053 commit 1a95e59

File tree

4 files changed

+77
-13
lines changed

4 files changed

+77
-13
lines changed

evaluation.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1692,11 +1692,13 @@ def run_ancestor_quality(args):
16921692
def get_node_degree_by_depth(ts):
16931693
"""
16941694
Returns a tuple (degree, depth) for each node in each tree in the
1695-
specified tree sequence.
1695+
specified tree sequence (empty flanking regions are omitted)
16961696
"""
16971697
degree = []
16981698
depth = []
16991699
for tree in ts.trees():
1700+
if tree.num_edges == 0 and (tree.index == 0 or tree.index == ts.num_trees - 1):
1701+
continue
17001702
stack = [(tree.root, 0)]
17011703
while len(stack) > 0:
17021704
u, d = stack.pop()

tests/test_inference.py

Lines changed: 40 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -753,7 +753,8 @@ def verify(self, genotypes):
753753
exclude_positions=exclude_positions,
754754
)
755755
for tree in output_ts.trees():
756-
assert tree.num_roots == 1
756+
if tree.num_edges > 0 or 0 < tree.index < output_ts.num_trees - 1:
757+
assert tree.num_roots == 1
757758
for site in output_ts.sites():
758759
inf_type = json.loads(site.metadata)["inference_type"]
759760
if len(site.mutations) == 0:
@@ -2682,6 +2683,23 @@ def test_single_tree(self, small_ts_fixture):
26822683
def test_many_trees(self, medium_ts_fixture):
26832684
self.verify(medium_ts_fixture)
26842685

2686+
def test_flanking_regions_deleted(self, small_sd_fixture):
2687+
ts1 = tsinfer.infer(small_sd_fixture)
2688+
assert ts1.site(-1).position + 1 < ts1.sequence_length
2689+
assert ts1.first().num_edges == 0
2690+
assert ts1.last().num_edges == 0
2691+
assert ts1.first().interval.right == small_sd_fixture.sites_position[0]
2692+
assert ts1.last().interval.left == small_sd_fixture.sites_position[-1] + 1
2693+
2694+
# If seq length is less than the last pos + 1, right flank is not deleted
2695+
sd = small_sd_fixture.subset(sequence_length=ts1.site(-1).position + 0.1)
2696+
ts2 = tsinfer.infer(sd)
2697+
assert ts2.first().num_edges == 0
2698+
assert ts2.last().num_edges != 0
2699+
assert ts2.first().interval.right == sd.sites_position[0]
2700+
2701+
assert ts2.num_trees == ts1.num_trees - 1
2702+
26852703
def test_standalone_post_process(self, medium_sd_fixture):
26862704
# test we can post process separately, e.g. omitting the MRCA splitting step
26872705
ts_unsimplified = tsinfer.infer(medium_sd_fixture, post_process=False)
@@ -2700,7 +2718,7 @@ def test_standalone_post_process(self, medium_sd_fixture):
27002718
md = json.loads(md.decode()) # At the moment node metadata has no schema
27012719
assert md["ancestor_data_id"] == 1
27022720

2703-
ts = tsinfer.post_process(ts_unsimplified, split_mrca=True)
2721+
ts = tsinfer.post_process(ts_unsimplified, split_mrca=True, erase_flanks=False)
27042722
oldest_parent_id = ts.edge(-1).parent
27052723
assert np.sum(ts.tables.nodes.time == ts.node(oldest_parent_id).time) > 1
27062724
roots = set()
@@ -2719,11 +2737,13 @@ def test_post_process_non_tsinfer(self, small_ts_fixture, caplog):
27192737
small_ts_fixture.samples() == np.arange(small_ts_fixture.num_samples)
27202738
)
27212739
with caplog.at_level(logging.WARNING):
2722-
ts_postprocessed = tsinfer.post_process(small_ts_fixture)
2740+
ts_postprocessed = tsinfer.post_process(
2741+
small_ts_fixture, erase_flanks=False
2742+
)
27232743
assert caplog.text.count("virtual-root-like") == 0
27242744
with caplog.at_level(logging.WARNING):
27252745
ts_postprocessed = tsinfer.post_process(
2726-
small_ts_fixture, warn_if_unexpected_format=True
2746+
small_ts_fixture, warn_if_unexpected_format=True, erase_flanks=False
27272747
)
27282748
assert caplog.text.count("virtual-root-like") == 1
27292749

@@ -2747,8 +2767,8 @@ def test_virtual_like_root_removed(self, medium_sd_fixture):
27472767

27482768
def test_split_edges_one_tree(self, small_sd_fixture):
27492769
ts = tsinfer.infer(small_sd_fixture, post_process=False)
2750-
ts = tsinfer.post_process(ts, split_mrca=False)
27512770
assert ts.num_trees == 1
2771+
ts = tsinfer.post_process(ts, split_mrca=False)
27522772
# Check that we don't delete and recreate the oldest node if there's only 1 tree
27532773
tables = ts.dump_tables()
27542774
oldest_node_in_topology = tables.edges[-1].parent
@@ -2758,7 +2778,7 @@ def test_split_edges_one_tree(self, small_sd_fixture):
27582778

27592779
def test_dont_split_edges_twice(self, medium_sd_fixture, caplog):
27602780
ts = tsinfer.infer(medium_sd_fixture, post_process=False)
2761-
ts = tsinfer.post_process(ts, split_mrca=False)
2781+
ts = tsinfer.post_process(ts, split_mrca=False, erase_flanks=False)
27622782
assert ts.num_trees > 1
27632783
assert tsinfer.has_same_root_everywhere(ts)
27642784
# Once the mrca has been split, it can't be split again
@@ -2801,6 +2821,18 @@ def test_sample_order(self, medium_sd_fixture):
28012821
inferred_individual = ts.individual(ts.node(inferred_node_id).individual)
28022822
assert sd_individual.metadata["id"] == inferred_individual.metadata["id"]
28032823

2824+
def test_erase_flanks(self, small_sd_fixture):
2825+
ts1 = tsinfer.infer(small_sd_fixture, post_process=False)
2826+
ts2 = tsinfer.post_process(ts1, erase_flanks=False)
2827+
assert ts2.first().num_edges > 0
2828+
assert ts2.last().num_edges > 0
2829+
assert ts1.num_trees == ts2.num_trees
2830+
2831+
ts2 = tsinfer.post_process(ts1, erase_flanks=True)
2832+
assert ts2.first().num_edges == 0
2833+
assert ts2.last().num_edges == 0
2834+
assert ts1.num_trees == ts2.num_trees - 2
2835+
28042836

28052837
def get_default_inference_sites(sample_data):
28062838
"""
@@ -3990,7 +4022,8 @@ def test_nan_sites(self):
39904022
sample_data.add_site(0.4, [1, 1, 0], time=np.nan)
39914023
sample_data.add_site(0.6, [1, 1, 0])
39924024
ts = tsinfer.infer(sample_data)
3993-
assert ts.num_trees == 1
4025+
num_nonempty_trees = sum(1 for tree in ts.trees() if tree.num_edges > 0)
4026+
assert num_nonempty_trees == 1
39944027
inf_type = [json.loads(site.metadata)["inference_type"] for site in ts.sites()]
39954028
assert inf_type[0] == tsinfer.INFERENCE_FULL
39964029
assert inf_type[1] == tsinfer.INFERENCE_PARSIMONY

tsinfer/eval_util.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -776,9 +776,14 @@ def node_span(ts):
776776
for u in [edge.parent, edge.child]:
777777
if start[u] == -1 and tree.num_samples(u) > 0:
778778
start[u] = left
779+
# add intervals for the last tree
779780
for u in tree.nodes():
780781
if tree.num_samples(u) > 0:
781782
S[u] += ts.sequence_length - start[u]
783+
# isolated sample nodes (e.g. where topology was deleted) will have been missed when
784+
# iterating over edges, but by definition they span the entire ts, so override them
785+
for u in ts.samples():
786+
S[u] = ts.sequence_length
782787
return S
783788

784789

tsinfer/inference.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1950,29 +1950,35 @@ def post_process(
19501950
ts,
19511951
*,
19521952
split_mrca=None,
1953+
erase_flanks=None,
19531954
# Parameters below deliberately undocumented
19541955
warn_if_unexpected_format=None,
19551956
simplify_only=None,
19561957
):
19571958
"""
1958-
post_process(ts, *, split_mrca=None)
1959+
post_process(ts, *, split_mrca=None, erase_flanks=None)
19591960
19601961
Post-process a tsinferred tree sequence into a more conventional form. This is
19611962
the function run by default on the final tree sequence output by
1962-
:func:`match_samples`. It involves the following 3 steps:
1963+
:func:`match_samples`. It involves the following 4 steps:
19631964
19641965
#. If the oldest node is connected to a single child via an edge that spans the
19651966
entire tree sequence, this oldest node is removed, so that its child becomes
19661967
the new root (this step is undertaken to remove the "virtual-root-like node"
19671968
which is added to ancestor tree sequences to enable matching).
19681969
#. If the oldest node is removed in the first step and the new root spans the
1969-
entire genome, it is treated as the "grand MRCA" and (if split_mrca is ``True``)
1970-
the node is split into multiple coexisiting nodes with the splits
1970+
entire genome, it is treated as the "grand MRCA" and (unless split_mrca is
1971+
``False``) the node is split into multiple coexisiting nodes with the splits
19711972
occurring whenever the children of the grand MRCA change. The rationale
19721973
is that tsinfer creates a grand MRCA consisting of a single ancestral haplotype
19731974
with all inference sites in the ancestral state: this is, however, unlikely
19741975
to represent a single ancestor in the past. If nodes in the tree sequence are
19751976
then dated, these MRCA nodes can be pushed to different times.
1977+
#. Often, extensive regions of genome exist before the first defined site and after
1978+
the last defined site. Since no data exists in these sections of the genome, post
1979+
processing by default erases the inferred topology in these regions. However,
1980+
if ``erase_flanks`` is False, the flanking regions at the start and end will be
1981+
assigned the same topology as inferred at the first and last site respectively.
19761982
#. The sample nodes are reordered such that they are the first nodes listed in the
19771983
node table, removing tree nodes and edges that are not on a path between the
19781984
root and any of the samples (by applying the :meth:`~tskit.TreeSequence.simplify`
@@ -1982,11 +1988,17 @@ def post_process(
19821988
:param bool split_mrca: If ``True`` (default) and the oldest node is the only
19831989
parent to a single "grand MRCA", split the grand MRCA into separate nodes
19841990
(see above). If ``False`` do not attempt to identify or split a grand MRCA.
1991+
:param bool erase_flanks: If ``True`` (default), keep only the
1992+
inferred topology between the first and last sites. If ``False``,
1993+
output regions of topology inferred before the first site and after
1994+
the last site.
19851995
:return: The post-processed tree sequence.
19861996
:rtype: tskit.TreeSequence
19871997
"""
19881998
if split_mrca is None:
19891999
split_mrca = True
2000+
if erase_flanks is None:
2001+
erase_flanks = True
19902002

19912003
tables = ts.dump_tables()
19922004

@@ -2011,10 +2023,21 @@ def post_process(
20112023
"Cannot find a virtual-root-like ancestor during preprocessing"
20122024
)
20132025

2026+
if erase_flanks and ts.num_sites > 0:
2027+
logger.info("Removing topology in flanking regions with keep_intervals")
2028+
# So that the last site falls within a tree, we must add one to the
2029+
# site position (or simply extend to the end of the ts)
2030+
upper_cutoff = min(ts.sites_position[-1] + 1, ts.sequence_length)
2031+
tables.keep_intervals(
2032+
[[ts.sites_position[0], upper_cutoff]],
2033+
simplify=False,
2034+
record_provenance=False,
2035+
)
2036+
20142037
logger.info(
20152038
"Simplifying with filter_sites=False, filter_populations=False, "
20162039
"filter_individuals=False, and keep_unary=True on "
2017-
f"{tables.nodes.num_rows} nodes and {ts.num_edges} edges"
2040+
f"{tables.nodes.num_rows} nodes and {tables.edges.num_rows} edges"
20182041
)
20192042
# NB: if this is an inferred TS, match_samples is guaranteed to produce samples
20202043
# in the same order as passed in to sample_indexes, and simplification will
@@ -2024,6 +2047,7 @@ def post_process(
20242047
filter_populations=False,
20252048
filter_individuals=False,
20262049
keep_unary=True,
2050+
record_provenance=False,
20272051
)
20282052
logger.info(
20292053
"Finished simplify; now have {} nodes and {} edges".format(

0 commit comments

Comments
 (0)