@@ -753,7 +753,8 @@ def verify(self, genotypes):
753
753
exclude_positions = exclude_positions ,
754
754
)
755
755
for tree in output_ts .trees ():
756
- assert tree .num_roots == 1
756
+ if tree .num_edges > 0 or 0 < tree .index < output_ts .num_trees - 1 :
757
+ assert tree .num_roots == 1
757
758
for site in output_ts .sites ():
758
759
inf_type = json .loads (site .metadata )["inference_type" ]
759
760
if len (site .mutations ) == 0 :
@@ -2682,6 +2683,23 @@ def test_single_tree(self, small_ts_fixture):
2682
2683
def test_many_trees (self , medium_ts_fixture ):
2683
2684
self .verify (medium_ts_fixture )
2684
2685
2686
+ def test_flanking_regions_deleted (self , small_sd_fixture ):
2687
+ ts1 = tsinfer .infer (small_sd_fixture )
2688
+ assert ts1 .site (- 1 ).position + 1 < ts1 .sequence_length
2689
+ assert ts1 .first ().num_edges == 0
2690
+ assert ts1 .last ().num_edges == 0
2691
+ assert ts1 .first ().interval .right == small_sd_fixture .sites_position [0 ]
2692
+ assert ts1 .last ().interval .left == small_sd_fixture .sites_position [- 1 ] + 1
2693
+
2694
+ # If seq length is less than the last pos + 1, right flank is not deleted
2695
+ sd = small_sd_fixture .subset (sequence_length = ts1 .site (- 1 ).position + 0.1 )
2696
+ ts2 = tsinfer .infer (sd )
2697
+ assert ts2 .first ().num_edges == 0
2698
+ assert ts2 .last ().num_edges != 0
2699
+ assert ts2 .first ().interval .right == sd .sites_position [0 ]
2700
+
2701
+ assert ts2 .num_trees == ts1 .num_trees - 1
2702
+
2685
2703
def test_standalone_post_process (self , medium_sd_fixture ):
2686
2704
# test we can post process separately, e.g. omitting the MRCA splitting step
2687
2705
ts_unsimplified = tsinfer .infer (medium_sd_fixture , post_process = False )
@@ -2700,7 +2718,7 @@ def test_standalone_post_process(self, medium_sd_fixture):
2700
2718
md = json .loads (md .decode ()) # At the moment node metadata has no schema
2701
2719
assert md ["ancestor_data_id" ] == 1
2702
2720
2703
- ts = tsinfer .post_process (ts_unsimplified , split_mrca = True )
2721
+ ts = tsinfer .post_process (ts_unsimplified , split_mrca = True , erase_flanks = False )
2704
2722
oldest_parent_id = ts .edge (- 1 ).parent
2705
2723
assert np .sum (ts .tables .nodes .time == ts .node (oldest_parent_id ).time ) > 1
2706
2724
roots = set ()
@@ -2719,11 +2737,13 @@ def test_post_process_non_tsinfer(self, small_ts_fixture, caplog):
2719
2737
small_ts_fixture .samples () == np .arange (small_ts_fixture .num_samples )
2720
2738
)
2721
2739
with caplog .at_level (logging .WARNING ):
2722
- ts_postprocessed = tsinfer .post_process (small_ts_fixture )
2740
+ ts_postprocessed = tsinfer .post_process (
2741
+ small_ts_fixture , erase_flanks = False
2742
+ )
2723
2743
assert caplog .text .count ("virtual-root-like" ) == 0
2724
2744
with caplog .at_level (logging .WARNING ):
2725
2745
ts_postprocessed = tsinfer .post_process (
2726
- small_ts_fixture , warn_if_unexpected_format = True
2746
+ small_ts_fixture , warn_if_unexpected_format = True , erase_flanks = False
2727
2747
)
2728
2748
assert caplog .text .count ("virtual-root-like" ) == 1
2729
2749
@@ -2747,8 +2767,8 @@ def test_virtual_like_root_removed(self, medium_sd_fixture):
2747
2767
2748
2768
def test_split_edges_one_tree (self , small_sd_fixture ):
2749
2769
ts = tsinfer .infer (small_sd_fixture , post_process = False )
2750
- ts = tsinfer .post_process (ts , split_mrca = False )
2751
2770
assert ts .num_trees == 1
2771
+ ts = tsinfer .post_process (ts , split_mrca = False )
2752
2772
# Check that we don't delete and recreate the oldest node if there's only 1 tree
2753
2773
tables = ts .dump_tables ()
2754
2774
oldest_node_in_topology = tables .edges [- 1 ].parent
@@ -2758,7 +2778,7 @@ def test_split_edges_one_tree(self, small_sd_fixture):
2758
2778
2759
2779
def test_dont_split_edges_twice (self , medium_sd_fixture , caplog ):
2760
2780
ts = tsinfer .infer (medium_sd_fixture , post_process = False )
2761
- ts = tsinfer .post_process (ts , split_mrca = False )
2781
+ ts = tsinfer .post_process (ts , split_mrca = False , erase_flanks = False )
2762
2782
assert ts .num_trees > 1
2763
2783
assert tsinfer .has_same_root_everywhere (ts )
2764
2784
# Once the mrca has been split, it can't be split again
@@ -2801,6 +2821,18 @@ def test_sample_order(self, medium_sd_fixture):
2801
2821
inferred_individual = ts .individual (ts .node (inferred_node_id ).individual )
2802
2822
assert sd_individual .metadata ["id" ] == inferred_individual .metadata ["id" ]
2803
2823
2824
+ def test_erase_flanks (self , small_sd_fixture ):
2825
+ ts1 = tsinfer .infer (small_sd_fixture , post_process = False )
2826
+ ts2 = tsinfer .post_process (ts1 , erase_flanks = False )
2827
+ assert ts2 .first ().num_edges > 0
2828
+ assert ts2 .last ().num_edges > 0
2829
+ assert ts1 .num_trees == ts2 .num_trees
2830
+
2831
+ ts2 = tsinfer .post_process (ts1 , erase_flanks = True )
2832
+ assert ts2 .first ().num_edges == 0
2833
+ assert ts2 .last ().num_edges == 0
2834
+ assert ts1 .num_trees == ts2 .num_trees - 2
2835
+
2804
2836
2805
2837
def get_default_inference_sites (sample_data ):
2806
2838
"""
@@ -3990,7 +4022,8 @@ def test_nan_sites(self):
3990
4022
sample_data .add_site (0.4 , [1 , 1 , 0 ], time = np .nan )
3991
4023
sample_data .add_site (0.6 , [1 , 1 , 0 ])
3992
4024
ts = tsinfer .infer (sample_data )
3993
- assert ts .num_trees == 1
4025
+ num_nonempty_trees = sum (1 for tree in ts .trees () if tree .num_edges > 0 )
4026
+ assert num_nonempty_trees == 1
3994
4027
inf_type = [json .loads (site .metadata )["inference_type" ] for site in ts .sites ()]
3995
4028
assert inf_type [0 ] == tsinfer .INFERENCE_FULL
3996
4029
assert inf_type [1 ] == tsinfer .INFERENCE_PARSIMONY
0 commit comments