Skip to content

Commit 6592b5a

Browse files
committed
Use map_mutations from tskit, rather than bespoke locate_mutations_on_tree()
Remove old implementation change docs Tidy comments and variable names
1 parent 2008cb3 commit 6592b5a

File tree

2 files changed

+21
-60
lines changed

2 files changed

+21
-60
lines changed

docs/inference.rst

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -203,15 +203,10 @@ The final phase of a ``tsinfer`` inference consists of a number steps:
203203
2. As we only use a subset of the available sites for inference
204204
(excluding by default any sites that are fixed or singletons)
205205
we then place mutations on the inferred trees in order to
206-
represent the information at these sites. We currently use a
207-
form of Dollo parsimony to do this. For a given site with
208-
a set of samples with the derived state, first find the MRCA
209-
of these samples, and place a mutation at this node. Then,
210-
for all samples in this subtree that carry the ancestral
211-
state, place a back mutation to the ancestral state directly over
212-
this sample. **Note this approach is suboptimal because there
213-
may be clades of ancestral state samples which would allow us
214-
to encode the data with fewer back mutations.**
206+
represent the information at these sites. This is done using the tskit
207+
`map_mutations <https://tskit.readthedocs.io/en/latest/python-api.html#tskit.Tree.map_mutations>`_.
208+
method.
209+
215210

216211
3. Remove ancestral paths that do not lead to any of the samples by
217212
`simplifying

tsinfer/inference.py

Lines changed: 17 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -627,54 +627,6 @@ def get_ancestors_tree_sequence(self):
627627
def encode_metadata(self, value):
628628
return json.dumps(value).encode()
629629

630-
def locate_mutations_on_tree(self, tree, variant, mutations):
631-
"""
632-
Find the most parsimonious way to place mutations to define the specified
633-
genotypes on the specified tree, and update the mutation table accordingly.
634-
"""
635-
site = variant.site.id
636-
samples = np.where(variant.genotypes == 1)[0]
637-
num_samples = len(samples)
638-
logger.debug("Locating mutations for site {}; n = {}".format(site, num_samples))
639-
# Nothing to do if this site is fixed for the ancestral state.
640-
if num_samples == 0:
641-
return
642-
643-
# count = np.zeros(tree.tree_sequence.num_nodes, dtype=int)
644-
count = collections.Counter()
645-
for sample in samples:
646-
u = self.sample_ids[sample]
647-
while u != tskit.NULL:
648-
count[u] += 1
649-
u = tree.parent(u)
650-
# Go up the tree until we find the first node ancestral to all samples.
651-
node = self.sample_ids[samples[0]]
652-
while count[node] < num_samples:
653-
node = tree.parent(node)
654-
assert count[node] == num_samples
655-
# Look at the children of this node and put down mutations appropriately.
656-
split_children = False
657-
for child in tree.children(node):
658-
if count[child] == 0:
659-
split_children = True
660-
break
661-
mutation_nodes = [node]
662-
if split_children:
663-
mutation_nodes = []
664-
for child in tree.children(node):
665-
if count[child] > 0:
666-
mutation_nodes.append(child)
667-
for mutation_node in mutation_nodes:
668-
parent_mutation = mutations.add_row(
669-
site=site, node=mutation_node, derived_state=variant.alleles[1])
670-
# Traverse down the tree to find any leaves that do not have this
671-
# mutation and insert back mutations.
672-
for node in tree.nodes(mutation_node):
673-
if tree.is_sample(node) and count[node] == 0:
674-
mutations.add_row(
675-
site=site, node=node, derived_state=variant.alleles[0],
676-
parent=parent_mutation)
677-
678630
def insert_sites(self, tables):
679631
"""
680632
Insert the sites in the sample data that were not marked for inference,
@@ -688,6 +640,7 @@ def insert_sites(self, tables):
688640
_, node, derived_state, parent = self.tree_sequence_builder.dump_mutations()
689641
ts = tables.tree_sequence()
690642
if num_non_inference_sites > 0:
643+
assert ts.num_edges > 0
691644
logger.info(
692645
"Starting mutation positioning for {} non inference sites".format(
693646
num_non_inference_sites))
@@ -696,21 +649,34 @@ def insert_sites(self, tables):
696649
tree = next(trees)
697650
for variant in self.sample_data.variants():
698651
site = variant.site
652+
predefined_anc_state = site.ancestral_state
699653
while tree.interval[1] <= site.position:
700654
tree = next(trees)
701655
assert tree.interval[0] <= site.position < tree.interval[1]
702656
tables.sites.add_row(
703657
position=site.position,
704-
ancestral_state=site.ancestral_state,
658+
ancestral_state=predefined_anc_state,
705659
metadata=self.encode_metadata(site.metadata))
706660
if site.inference == 1:
707661
tables.mutations.add_row(
708662
site=site.id, node=node[inferred_site],
709663
derived_state=variant.alleles[derived_state[inferred_site]])
710664
inferred_site += 1
711665
else:
712-
assert ts.num_edges > 0
713-
self.locate_mutations_on_tree(tree, variant, tables.mutations)
666+
inferred_anc_state, mapped_mutations = tree.map_mutations(
667+
variant.genotypes, variant.alleles)
668+
if inferred_anc_state != predefined_anc_state:
669+
# We need to set the ancestral state to that defined in the
670+
# original file
671+
for root_node in tree.roots:
672+
# Add a transition at each root to the mapped value
673+
tables.mutations.add_row(
674+
site=site.id, node=root_node,
675+
derived_state=inferred_anc_state)
676+
for mutation in mapped_mutations:
677+
tables.mutations.add_row(
678+
site=site.id, node=mutation.node,
679+
derived_state=mutation.derived_state)
714680
progress_monitor.update()
715681
else:
716682
# Simple case where all sites are inference sites. We save a lot of time here

0 commit comments

Comments
 (0)