Skip to content

Commit fddef64

Browse files
committed
Use map_mutations from tskit, rather than bespoke locate_mutations_on_tree()
Remove old implementation change docs Tidy comments and variable names
1 parent 800f806 commit fddef64

File tree

2 files changed

+33
-60
lines changed

2 files changed

+33
-60
lines changed

docs/inference.rst

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -188,15 +188,10 @@ The final phase of a ``tsinfer`` inference consists of a number steps:
188188
2. As we only use a subset of the available sites for inference
189189
(excluding by default any sites that are fixed or singletons)
190190
we then place mutations on the inferred trees in order to
191-
represent the information at these sites. We currently use a
192-
form of Dollo parsimony to do this. For a given site with
193-
a set of samples with the derived state, first find the MRCA
194-
of these samples, and place a mutation at this node. Then,
195-
for all samples in this subtree that carry the ancestral
196-
state, place a back mutation to the ancestral state directly over
197-
this sample. **Note this approach is suboptimal because there
198-
may be clades of ancestral state samples which would allow us
199-
to encode the data with fewer back mutations.**
191+
represent the information at these sites. This is done using the tskit
192+
`map_mutations <https://tskit.readthedocs.io/en/latest/python-api.html#tskit.Tree.map_mutations>`_.
193+
method.
194+
200195

201196
3. Reduce the resulting tree sequence to a canonical form by
202197
`simplifying it

tsinfer/inference.py

Lines changed: 29 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -629,54 +629,6 @@ def get_ancestors_tree_sequence(self):
629629
def encode_metadata(self, value):
630630
return json.dumps(value).encode()
631631

632-
def locate_mutations_on_tree(self, tree, variant, mutations):
633-
"""
634-
Find the most parsimonious way to place mutations to define the specified
635-
genotypes on the specified tree, and update the mutation table accordingly.
636-
"""
637-
site = variant.site.id
638-
samples = np.where(variant.genotypes == 1)[0]
639-
num_samples = len(samples)
640-
logger.debug("Locating mutations for site {}; n = {}".format(site, num_samples))
641-
# Nothing to do if this site is fixed for the ancestral state.
642-
if num_samples == 0:
643-
return
644-
645-
# count = np.zeros(tree.tree_sequence.num_nodes, dtype=int)
646-
count = collections.Counter()
647-
for sample in samples:
648-
u = self.sample_ids[sample]
649-
while u != tskit.NULL:
650-
count[u] += 1
651-
u = tree.parent(u)
652-
# Go up the tree until we find the first node ancestral to all samples.
653-
node = self.sample_ids[samples[0]]
654-
while count[node] < num_samples:
655-
node = tree.parent(node)
656-
assert count[node] == num_samples
657-
# Look at the children of this node and put down mutations appropriately.
658-
split_children = False
659-
for child in tree.children(node):
660-
if count[child] == 0:
661-
split_children = True
662-
break
663-
mutation_nodes = [node]
664-
if split_children:
665-
mutation_nodes = []
666-
for child in tree.children(node):
667-
if count[child] > 0:
668-
mutation_nodes.append(child)
669-
for mutation_node in mutation_nodes:
670-
parent_mutation = mutations.add_row(
671-
site=site, node=mutation_node, derived_state=variant.alleles[1])
672-
# Traverse down the tree to find any leaves that do not have this
673-
# mutation and insert back mutations.
674-
for node in tree.nodes(mutation_node):
675-
if tree.is_sample(node) and count[node] == 0:
676-
mutations.add_row(
677-
site=site, node=node, derived_state=variant.alleles[0],
678-
parent=parent_mutation)
679-
680632
def insert_sites(self, tables):
681633
"""
682634
Insert the sites in the sample data that were not marked for inference,
@@ -690,6 +642,7 @@ def insert_sites(self, tables):
690642
_, node, derived_state, parent = self.tree_sequence_builder.dump_mutations()
691643
ts = tables.tree_sequence()
692644
if num_non_inference_sites > 0:
645+
assert ts.num_edges > 0
693646
logger.info(
694647
"Starting mutation positioning for {} non inference sites".format(
695648
num_non_inference_sites))
@@ -698,21 +651,46 @@ def insert_sites(self, tables):
698651
tree = next(trees)
699652
for variant in self.sample_data.variants():
700653
site = variant.site
654+
predefined_anc_state = site.ancestral_state
701655
while tree.interval[1] <= site.position:
702656
tree = next(trees)
703657
assert tree.interval[0] <= site.position < tree.interval[1]
704658
tables.sites.add_row(
705659
position=site.position,
706-
ancestral_state=site.ancestral_state,
660+
ancestral_state=predefined_anc_state,
707661
metadata=self.encode_metadata(site.metadata))
708662
if site.inference == 1:
709663
tables.mutations.add_row(
710664
site=site.id, node=node[inferred_site],
711665
derived_state=variant.alleles[derived_state[inferred_site]])
712666
inferred_site += 1
713667
else:
714-
assert ts.num_edges > 0
715-
self.locate_mutations_on_tree(tree, variant, tables.mutations)
668+
if np.all(variant.genotypes == tskit.MISSING_DATA):
669+
# Map_mutations has to have at least 1 non-missing value to work
670+
inferred_anc_state = tskit.MISSING_DATA
671+
mapped_mutations = []
672+
else:
673+
inferred_anc_state, mapped_mutations = tree.map_mutations(
674+
variant.genotypes, variant.alleles + (tskit.MISSING_DATA, ))
675+
if (predefined_anc_state != tskit.MISSING_DATA and
676+
inferred_anc_state != predefined_anc_state):
677+
# Non-inference sites whose ancestral state was defined as
678+
# missing have their ancestral state inferred. Otherwise we need
679+
# to set the ancestral state to that defined in the original file
680+
sample_missing = set(
681+
ts.samples()[variant.genotypes == tskit.MISSING_DATA])
682+
for root_node in tree.roots:
683+
# Add a transition at each root to the mapped value
684+
if tree.is_leaf(root_node) and root_node in sample_missing:
685+
# Except isolated tips (automatically flagged as missing)
686+
continue
687+
tables.mutations.add_row(
688+
site=site.id, node=root_node,
689+
derived_state=inferred_anc_state)
690+
for mutation in mapped_mutations:
691+
tables.mutations.add_row(
692+
site=site.id, node=mutation.node,
693+
derived_state=mutation.derived_state)
716694
progress_monitor.update()
717695
else:
718696
# Simple case where all sites are inference sites. We save a lot of time here

0 commit comments

Comments
 (0)