@@ -615,8 +615,8 @@ def verify_inserted_ancestors(self, ts):
615615 ancestor_data .finalise ()
616616
617617 A = np .full (
618- (ancestor_data .num_sites , ancestor_data .num_ancestors ),
619- tskit . MISSING_DATA , dtype = np .int8 )
618+ (ancestor_data .num_sites , ancestor_data .num_ancestors ), tskit . MISSING_DATA ,
619+ dtype = np .int8 )
620620 start = ancestor_data .ancestors_start [:]
621621 end = ancestor_data .ancestors_end [:]
622622 ancestors = ancestor_data .ancestors_haplotype [:]
@@ -1641,10 +1641,12 @@ class TestExtractAncestors(unittest.TestCase):
16411641 """
16421642 def verify (self , samples ):
16431643 ancestors = tsinfer .generate_ancestors (samples )
1644+ # this ancestors TS has positions mapped only to inference sites
16441645 ancestors_ts_1 = tsinfer .match_ancestors (samples , ancestors )
16451646 ts = tsinfer .match_samples (
16461647 samples , ancestors_ts_1 , path_compression = False , simplify = False )
16471648 t1 = ancestors_ts_1 .dump_tables ()
1649+
16481650 t2 , node_id_map = tsinfer .extract_ancestors (samples , ts )
16491651 self .assertEqual (len (t2 .provenances ), len (t1 .provenances ) + 2 )
16501652 t1 .provenances .clear ()
@@ -1654,14 +1656,6 @@ def verify(self, samples):
16541656 # for now.
16551657 t2 .populations .clear ()
16561658
1657- self .assertEqual (t1 .nodes , t2 .nodes )
1658- self .assertEqual (t1 .edges , t2 .edges )
1659- self .assertEqual (t1 .sites , t2 .sites )
1660- self .assertEqual (t1 .mutations , t2 .mutations )
1661- self .assertEqual (t1 .populations , t2 .populations )
1662- self .assertEqual (t1 .individuals , t2 .individuals )
1663- self .assertEqual (t1 .sites , t2 .sites )
1664-
16651659 self .assertEqual (t1 , t2 )
16661660
16671661 for node in ts .nodes ():
@@ -1675,8 +1669,7 @@ def test_simple_simulation(self):
16751669 def test_non_zero_one_mutations (self ):
16761670 ts = msprime .simulate (10 , recombination_rate = 5 , random_seed = 2 )
16771671 ts = msprime .mutate (
1678- ts , rate = 5 , model = msprime .InfiniteSites (msprime .NUCLEOTIDES ),
1679- random_seed = 15 )
1672+ ts , rate = 2 , model = msprime .InfiniteSites (msprime .NUCLEOTIDES ), random_seed = 15 )
16801673 self .assertGreater (ts .num_mutations , 0 )
16811674 self .verify (tsinfer .SampleData .from_tree_sequence (ts , use_times = False ))
16821675
@@ -1923,3 +1916,130 @@ def verify_example(self, full_subset, samples, ancestors, path_compression):
19231916 self .assertEqual (expected_sample_ancestors , num_sample_ancestors )
19241917 tsinfer .verify (samples , final_ts .simplify ())
19251918 ancestors_ts = augmented_ancestors
1919+
1920+
1921+ class TestMissingSampleDataInference (unittest .TestCase ):
1922+ """
1923+ Test that we can infer sites with tskit.MISSING_DATA, using both the PY and C engines
1924+ """
1925+ def test_missing_haplotypes (self ):
1926+ u = tskit .MISSING_DATA
1927+ sites_by_samples = np .array ([
1928+ [u , u , u , u ],
1929+ [u , 1 , 0 , u ],
1930+ [u , 0 , 1 , u ],
1931+ [u , 1 , 1 , u ]
1932+ ], dtype = np .int8 )
1933+ with tsinfer .SampleData () as sample_data :
1934+ for col in range (sites_by_samples .shape [1 ]):
1935+ sample_data .add_site (col , sites_by_samples [:, col ])
1936+ ts = tsinfer .infer (sample_data )
1937+ self .assertTrue (np .all (sites_by_samples == ts .genotype_matrix ().T ))
1938+
1939+ def test_samples_missing_inference_sites (self ):
1940+ u = tskit .MISSING_DATA
1941+ sites_by_samples = np .array ([
1942+ [1 , 0 , 0 , u ],
1943+ [1 , 0 , 0 , u ],
1944+ [0 , 1 , 1 , 1 ],
1945+ [u , u , u , 1 ]], dtype = np .int8 )
1946+ with tsinfer .SampleData () as sample_data :
1947+ for col in range (sites_by_samples .shape [1 ]):
1948+ sample_data .add_site (col , sites_by_samples [:, col ])
1949+ ts = tsinfer .infer (sample_data , simplify = False )
1950+ self .assertTrue (np .all (sites_by_samples == ts .genotype_matrix ().T ))
1951+
1952+ def test_small_truncated_fragments (self ):
1953+ u = tskit .MISSING_DATA
1954+ sites_by_samples = np .array ([
1955+ [u , u , u , 1 , 1 , 0 , 1 , 1 , 1 , u ],
1956+ [u , u , u , 1 , 0 , 0 , 1 , 1 , 0 , u ],
1957+ [u , u , u , 1 , 0 , 1 , 1 , 0 , 1 , u ],
1958+ [u , 0 , 0 , 1 , 0 , 1 , 1 , u , u , u ],
1959+ [u , 0 , 1 , 1 , 0 , 0 , 1 , u , u , u ],
1960+ [u , 1 , 1 , 0 , 0 , 0 , 0 , u , u , u ]
1961+ ], dtype = np .int8 )
1962+ with tsinfer .SampleData () as sample_data :
1963+ for col in range (sites_by_samples .shape [1 ]):
1964+ sample_data .add_site (col , sites_by_samples [:, col ])
1965+ for e in [tsinfer .PY_ENGINE , tsinfer .C_ENGINE ]:
1966+ ancestors = tsinfer .generate_ancestors (sample_data , engine = e )
1967+ ancestors_ts = tsinfer .match_ancestors (
1968+ sample_data , ancestors , engine = e , extended_checks = True )
1969+ ts = tsinfer .match_samples (
1970+ sample_data , ancestors_ts , engine = e , extended_checks = True )
1971+ self .assertTrue (1.0 in ts .breakpoints (True )) # End of left unknown region
1972+ self .assertTrue (3.0 in ts .breakpoints (True )) # End of 1st unknown batch
1973+ self .assertTrue (7.0 in ts .breakpoints (True )) # Start of 2nd unknown batch
1974+ self .assertTrue (9.0 in ts .breakpoints (True )) # Start of right unknown region
1975+ for tree in ts .trees ():
1976+ for s in ts .samples ():
1977+ if tree .interval [1 ] <= 1 :
1978+ self .assertTrue (tree .parent (s ) == tskit .NULL )
1979+ elif tree .interval [1 ] <= 3 :
1980+ if s in [0 , 1 , 2 ]:
1981+ self .assertTrue (tree .parent (s ) == tskit .NULL )
1982+ else :
1983+ self .assertTrue (tree .parent (s ) != tskit .NULL )
1984+ elif tree .interval [0 ] >= 9 :
1985+ self .assertTrue (tree .parent (s ) == tskit .NULL )
1986+ elif tree .interval [0 ] >= 7 :
1987+ if s in [3 , 4 , 5 ]:
1988+ self .assertTrue (tree .parent (s ) == tskit .NULL )
1989+ else :
1990+ self .assertTrue (tree .parent (s ) != tskit .NULL )
1991+
1992+ self .assertTrue (np .all (sites_by_samples == ts .genotype_matrix ().T ))
1993+
1994+ def test_large_truncated_fragments (self ):
1995+ """
1996+ A bit like fragments produced from a sequencer
1997+ """
1998+ def truncate_ts_samples (ts , average_span , random_seed , min_span = 5 ):
1999+ """
2000+ Create a tree sequence that has sample nodes which have been truncated
2001+ so that they span only a small region of the genome. The length of the
2002+ truncated spans is given by a poisson distribution whose mean is average_span
2003+ but which cannot go below a fixed min_span, or above the sequence_length
2004+
2005+ Samples are truncated by removing the edges that connect them to the rest
2006+ of the tree.
2007+ """
2008+ np .random .seed (random_seed )
2009+ # Make a dict of (left, right) tuples giving the new limits of each sample,
2010+ # keyed by sample ID.
2011+ keep = {}
2012+ # for simplicity, we pick lengths from a poisson distribution with mean average_span
2013+ for sample_id , span in zip (
2014+ ts .samples (), np .random .poisson (average_span , ts .num_samples )):
2015+ span = max (span , min_span )
2016+ span = min (span , ts .sequence_length )
2017+ start = np .random .uniform (0 , ts .sequence_length - span )
2018+ keep [sample_id ] = (start , start + span )
2019+
2020+ tables = ts .dump_tables ()
2021+ tables .edges .clear ()
2022+ for e in ts .tables .edges :
2023+ if e .child not in keep :
2024+ left , right = e .left , e .right
2025+ else :
2026+ if e .right <= keep [e .child ][0 ] or e .left >= keep [e .child ][1 ]:
2027+ continue # this edge is outside the focal region
2028+ else :
2029+ left = max (e .left , keep [e .child ][0 ])
2030+ right = min (e .right , keep [e .child ][1 ])
2031+ tables .edges .add_row (left , right , e .parent , e .child )
2032+ return tables .tree_sequence ()
2033+
2034+ ts = msprime .simulate (
2035+ 100 , Ne = 1e2 , length = 400 , recombination_rate = 1e-4 , mutation_rate = 2e-4 ,
2036+ random_seed = 1 )
2037+ truncated_ts = truncate_ts_samples (ts , average_span = 200 , random_seed = 123 )
2038+ sd = tsinfer .SampleData .from_tree_sequence (truncated_ts , use_times = False )
2039+ # Cannot use the normal `simplify` as this removes parts of the TS where only
2040+ # one sample is connected to the root (& the other samples have missing data)
2041+ ts_inferred = tsinfer .infer (sd , simplify = False )
2042+ # Instead we run simplify explicitly, with `keep_unary=True`
2043+ ts_inferred = ts_inferred .simplify (filter_sites = False , keep_unary = True )
2044+ self .assertTrue (
2045+ np .all (ts_inferred .genotype_matrix () == truncated_ts .genotype_matrix ()))
0 commit comments