@@ -627,54 +627,6 @@ def get_ancestors_tree_sequence(self):
627627 def encode_metadata (self , value ):
628628 return json .dumps (value ).encode ()
629629
630- def locate_mutations_on_tree (self , tree , variant , mutations ):
631- """
632- Find the most parsimonious way to place mutations to define the specified
633- genotypes on the specified tree, and update the mutation table accordingly.
634- """
635- site = variant .site .id
636- samples = np .where (variant .genotypes == 1 )[0 ]
637- num_samples = len (samples )
638- logger .debug ("Locating mutations for site {}; n = {}" .format (site , num_samples ))
639- # Nothing to do if this site is fixed for the ancestral state.
640- if num_samples == 0 :
641- return
642-
643- # count = np.zeros(tree.tree_sequence.num_nodes, dtype=int)
644- count = collections .Counter ()
645- for sample in samples :
646- u = self .sample_ids [sample ]
647- while u != tskit .NULL :
648- count [u ] += 1
649- u = tree .parent (u )
650- # Go up the tree until we find the first node ancestral to all samples.
651- node = self .sample_ids [samples [0 ]]
652- while count [node ] < num_samples :
653- node = tree .parent (node )
654- assert count [node ] == num_samples
655- # Look at the children of this node and put down mutations appropriately.
656- split_children = False
657- for child in tree .children (node ):
658- if count [child ] == 0 :
659- split_children = True
660- break
661- mutation_nodes = [node ]
662- if split_children :
663- mutation_nodes = []
664- for child in tree .children (node ):
665- if count [child ] > 0 :
666- mutation_nodes .append (child )
667- for mutation_node in mutation_nodes :
668- parent_mutation = mutations .add_row (
669- site = site , node = mutation_node , derived_state = variant .alleles [1 ])
670- # Traverse down the tree to find any leaves that do not have this
671- # mutation and insert back mutations.
672- for node in tree .nodes (mutation_node ):
673- if tree .is_sample (node ) and count [node ] == 0 :
674- mutations .add_row (
675- site = site , node = node , derived_state = variant .alleles [0 ],
676- parent = parent_mutation )
677-
678630 def insert_sites (self , tables ):
679631 """
680632 Insert the sites in the sample data that were not marked for inference,
@@ -688,6 +640,7 @@ def insert_sites(self, tables):
688640 _ , node , derived_state , parent = self .tree_sequence_builder .dump_mutations ()
689641 ts = tables .tree_sequence ()
690642 if num_non_inference_sites > 0 :
643+ assert ts .num_edges > 0
691644 logger .info (
692645 "Starting mutation positioning for {} non inference sites" .format (
693646 num_non_inference_sites ))
@@ -696,21 +649,34 @@ def insert_sites(self, tables):
696649 tree = next (trees )
697650 for variant in self .sample_data .variants ():
698651 site = variant .site
652+ predefined_anc_state = site .ancestral_state
699653 while tree .interval [1 ] <= site .position :
700654 tree = next (trees )
701655 assert tree .interval [0 ] <= site .position < tree .interval [1 ]
702656 tables .sites .add_row (
703657 position = site .position ,
704- ancestral_state = site . ancestral_state ,
658+ ancestral_state = predefined_anc_state ,
705659 metadata = self .encode_metadata (site .metadata ))
706660 if site .inference == 1 :
707661 tables .mutations .add_row (
708662 site = site .id , node = node [inferred_site ],
709663 derived_state = variant .alleles [derived_state [inferred_site ]])
710664 inferred_site += 1
711665 else :
712- assert ts .num_edges > 0
713- self .locate_mutations_on_tree (tree , variant , tables .mutations )
666+ inferred_anc_state , mapped_mutations = tree .map_mutations (
667+ variant .genotypes , variant .alleles )
668+ if inferred_anc_state != predefined_anc_state :
669+ # We need to set the ancestral state to that defined in the
670+ # original file
671+ for root_node in tree .roots :
672+ # Add a transition at each root to the mapped value
673+ tables .mutations .add_row (
674+ site = site .id , node = root_node ,
675+ derived_state = inferred_anc_state )
676+ for mutation in mapped_mutations :
677+ tables .mutations .add_row (
678+ site = site .id , node = mutation .node ,
679+ derived_state = mutation .derived_state )
714680 progress_monitor .update ()
715681 else :
716682 # Simple case where all sites are inference sites. We save a lot of time here
0 commit comments