@@ -615,8 +615,8 @@ def verify_inserted_ancestors(self, ts):
615615 ancestor_data .finalise ()
616616
617617 A = np .full (
618- (ancestor_data .num_sites , ancestor_data .num_ancestors ),
619- tskit . MISSING_DATA , dtype = np .int8 )
618+ (ancestor_data .num_sites , ancestor_data .num_ancestors ), tskit . MISSING_DATA ,
619+ dtype = np .int8 )
620620 start = ancestor_data .ancestors_start [:]
621621 end = ancestor_data .ancestors_end [:]
622622 ancestors = ancestor_data .ancestors_haplotype [:]
@@ -1649,10 +1649,12 @@ class TestExtractAncestors(unittest.TestCase):
16491649 """
16501650 def verify (self , samples ):
16511651 ancestors = tsinfer .generate_ancestors (samples )
1652+ # this ancestors TS has positions mapped only to inference sites
16521653 ancestors_ts_1 = tsinfer .match_ancestors (samples , ancestors )
16531654 ts = tsinfer .match_samples (
16541655 samples , ancestors_ts_1 , path_compression = False , simplify = False )
16551656 t1 = ancestors_ts_1 .dump_tables ()
1657+
16561658 t2 , node_id_map = tsinfer .extract_ancestors (samples , ts )
16571659 self .assertEqual (len (t2 .provenances ), len (t1 .provenances ) + 2 )
16581660 t1 .provenances .clear ()
@@ -1662,14 +1664,6 @@ def verify(self, samples):
16621664 # for now.
16631665 t2 .populations .clear ()
16641666
1665- self .assertEqual (t1 .nodes , t2 .nodes )
1666- self .assertEqual (t1 .edges , t2 .edges )
1667- self .assertEqual (t1 .sites , t2 .sites )
1668- self .assertEqual (t1 .mutations , t2 .mutations )
1669- self .assertEqual (t1 .populations , t2 .populations )
1670- self .assertEqual (t1 .individuals , t2 .individuals )
1671- self .assertEqual (t1 .sites , t2 .sites )
1672-
16731667 self .assertEqual (t1 , t2 )
16741668
16751669 for node in ts .nodes ():
@@ -1683,8 +1677,7 @@ def test_simple_simulation(self):
16831677 def test_non_zero_one_mutations (self ):
16841678 ts = msprime .simulate (10 , recombination_rate = 5 , random_seed = 2 )
16851679 ts = msprime .mutate (
1686- ts , rate = 5 , model = msprime .InfiniteSites (msprime .NUCLEOTIDES ),
1687- random_seed = 15 )
1680+ ts , rate = 2 , model = msprime .InfiniteSites (msprime .NUCLEOTIDES ), random_seed = 15 )
16881681 self .assertGreater (ts .num_mutations , 0 )
16891682 self .verify (tsinfer .SampleData .from_tree_sequence (ts , use_times = False ))
16901683
@@ -1931,3 +1924,130 @@ def verify_example(self, full_subset, samples, ancestors, path_compression):
19311924 self .assertEqual (expected_sample_ancestors , num_sample_ancestors )
19321925 tsinfer .verify (samples , final_ts .simplify ())
19331926 ancestors_ts = augmented_ancestors
1927+
1928+
1929+ class TestMissingSampleDataInference (unittest .TestCase ):
1930+ """
1931+ Test that we can infer sites with tskit.MISSING_DATA, using both the PY and C engines
1932+ """
1933+ def test_missing_haplotypes (self ):
1934+ u = tskit .MISSING_DATA
1935+ sites_by_samples = np .array ([
1936+ [u , u , u , u ],
1937+ [u , 1 , 0 , u ],
1938+ [u , 0 , 1 , u ],
1939+ [u , 1 , 1 , u ]
1940+ ], dtype = np .int8 )
1941+ with tsinfer .SampleData () as sample_data :
1942+ for col in range (sites_by_samples .shape [1 ]):
1943+ sample_data .add_site (col , sites_by_samples [:, col ])
1944+ ts = tsinfer .infer (sample_data )
1945+ self .assertTrue (np .all (sites_by_samples == ts .genotype_matrix ().T ))
1946+
1947+ def test_samples_missing_inference_sites (self ):
1948+ u = tskit .MISSING_DATA
1949+ sites_by_samples = np .array ([
1950+ [1 , 0 , 0 , u ],
1951+ [1 , 0 , 0 , u ],
1952+ [0 , 1 , 1 , 1 ],
1953+ [u , u , u , 1 ]], dtype = np .int8 )
1954+ with tsinfer .SampleData () as sample_data :
1955+ for col in range (sites_by_samples .shape [1 ]):
1956+ sample_data .add_site (col , sites_by_samples [:, col ])
1957+ ts = tsinfer .infer (sample_data , simplify = False )
1958+ self .assertTrue (np .all (sites_by_samples == ts .genotype_matrix ().T ))
1959+
1960+ def test_small_truncated_fragments (self ):
1961+ u = tskit .MISSING_DATA
1962+ sites_by_samples = np .array ([
1963+ [u , u , u , 1 , 1 , 0 , 1 , 1 , 1 , u ],
1964+ [u , u , u , 1 , 0 , 0 , 1 , 1 , 0 , u ],
1965+ [u , u , u , 1 , 0 , 1 , 1 , 0 , 1 , u ],
1966+ [u , 0 , 0 , 1 , 0 , 1 , 1 , u , u , u ],
1967+ [u , 0 , 1 , 1 , 0 , 0 , 1 , u , u , u ],
1968+ [u , 1 , 1 , 0 , 0 , 0 , 0 , u , u , u ]
1969+ ], dtype = np .int8 )
1970+ with tsinfer .SampleData () as sample_data :
1971+ for col in range (sites_by_samples .shape [1 ]):
1972+ sample_data .add_site (col , sites_by_samples [:, col ])
1973+ for e in [tsinfer .PY_ENGINE , tsinfer .C_ENGINE ]:
1974+ ancestors = tsinfer .generate_ancestors (sample_data , engine = e )
1975+ ancestors_ts = tsinfer .match_ancestors (
1976+ sample_data , ancestors , engine = e , extended_checks = True )
1977+ ts = tsinfer .match_samples (
1978+ sample_data , ancestors_ts , engine = e , extended_checks = True )
1979+ self .assertTrue (1.0 in ts .breakpoints (True )) # End of lft unknown region
1980+ self .assertTrue (3.0 in ts .breakpoints (True )) # End of 1st unknown batch
1981+ self .assertTrue (7.0 in ts .breakpoints (True )) # Start of 2nd unknown batch
1982+ self .assertTrue (9.0 in ts .breakpoints (True )) # Start of rgt unknown region
1983+ for tree in ts .trees ():
1984+ for s in ts .samples ():
1985+ if tree .interval [1 ] <= 1 :
1986+ self .assertTrue (tree .parent (s ) == tskit .NULL )
1987+ elif tree .interval [1 ] <= 3 :
1988+ if s in [0 , 1 , 2 ]:
1989+ self .assertTrue (tree .parent (s ) == tskit .NULL )
1990+ else :
1991+ self .assertTrue (tree .parent (s ) != tskit .NULL )
1992+ elif tree .interval [0 ] >= 9 :
1993+ self .assertTrue (tree .parent (s ) == tskit .NULL )
1994+ elif tree .interval [0 ] >= 7 :
1995+ if s in [3 , 4 , 5 ]:
1996+ self .assertTrue (tree .parent (s ) == tskit .NULL )
1997+ else :
1998+ self .assertTrue (tree .parent (s ) != tskit .NULL )
1999+
2000+ self .assertTrue (np .all (sites_by_samples == ts .genotype_matrix ().T ))
2001+
2002+ def test_large_truncated_fragments (self ):
2003+ """
2004+ A bit like fragments produced from a sequencer
2005+ """
2006+ def truncate_ts_samples (ts , average_span , random_seed , min_span = 5 ):
2007+ """
2008+ Create a tree sequence that has sample nodes which have been truncated
2009+ so that they span only a small region of the genome. The length of the
2010+ truncated spans is given by a poisson distribution whose mean is average_span
2011+ but which cannot go below a fixed min_span, or above the sequence_length
2012+
2013+ Samples are truncated by removing the edges that connect them to the rest
2014+ of the tree.
2015+ """
2016+ np .random .seed (random_seed )
2017+ # Make a list of (left,right) tuples giving the new limits of each sample
2018+ # Keyed by sample ID.
2019+ keep = {}
2020+ # for simplicity, we pick lengths from a poisson distribution of av 300 bp
2021+ for sample_id , span in zip (
2022+ ts .samples (), np .random .poisson (average_span , ts .num_samples )):
2023+ span = max (span , min_span )
2024+ span = min (span , ts .sequence_length )
2025+ start = np .random .uniform (0 , ts .sequence_length - span )
2026+ keep [sample_id ] = (start , start + span )
2027+
2028+ tables = ts .dump_tables ()
2029+ tables .edges .clear ()
2030+ for e in ts .tables .edges :
2031+ if e .child not in keep :
2032+ left , right = e .left , e .right
2033+ else :
2034+ if e .right <= keep [e .child ][0 ] or e .left >= keep [e .child ][1 ]:
2035+ continue # this edge is outside the focal region
2036+ else :
2037+ left = max (e .left , keep [e .child ][0 ])
2038+ right = min (e .right , keep [e .child ][1 ])
2039+ tables .edges .add_row (left , right , e .parent , e .child )
2040+ return tables .tree_sequence ()
2041+
2042+ ts = msprime .simulate (
2043+ 100 , Ne = 1e2 , length = 400 , recombination_rate = 1e-4 , mutation_rate = 2e-4 ,
2044+ random_seed = 1 )
2045+ truncated_ts = truncate_ts_samples (ts , average_span = 200 , random_seed = 123 )
2046+ sd = tsinfer .SampleData .from_tree_sequence (truncated_ts , use_times = False )
2047+ # Cannot use the normal `simplify` as this removes parts of the TS where only
2048+ # one sample is connected to the root (& the other samples have missing data)
2049+ ts_inferred = tsinfer .infer (sd , simplify = False )
2050+ # Instead we run simplicy explicitly, with `keep_unary=True`
2051+ ts_inferred = ts_inferred .simplify (filter_sites = False , keep_unary = True )
2052+ self .assertTrue (
2053+ np .all (ts_inferred .genotype_matrix () == truncated_ts .genotype_matrix ()))
0 commit comments