Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 4 additions & 9 deletions docs/inference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -203,15 +203,10 @@ The final phase of a ``tsinfer`` inference consists of a number steps:
2. As we only use a subset of the available sites for inference
(excluding by default any sites that are fixed or singletons)
we then place mutations on the inferred trees in order to
represent the information at these sites. We currently use a
form of Dollo parsimony to do this. For a given site with
a set of samples with the derived state, first find the MRCA
of these samples, and place a mutation at this node. Then,
for all samples in this subtree that carry the ancestral
state, place a back mutation to the ancestral state directly over
this sample. **Note this approach is suboptimal because there
may be clades of ancestral state samples which would allow us
to encode the data with fewer back mutations.**
represent the information at these sites. This is done using the tskit
`map_mutations <https://tskit.readthedocs.io/en/latest/python-api.html#tskit.Tree.map_mutations>`_.
method.


3. Remove ancestral paths that do not lead to any of the samples by
`simplifying
Expand Down
68 changes: 17 additions & 51 deletions tsinfer/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,54 +627,6 @@ def get_ancestors_tree_sequence(self):
def encode_metadata(self, value):
return json.dumps(value).encode()

def locate_mutations_on_tree(self, tree, variant, mutations):
"""
Find the most parsimonious way to place mutations to define the specified
genotypes on the specified tree, and update the mutation table accordingly.
"""
site = variant.site.id
samples = np.where(variant.genotypes == 1)[0]
num_samples = len(samples)
logger.debug("Locating mutations for site {}; n = {}".format(site, num_samples))
# Nothing to do if this site is fixed for the ancestral state.
if num_samples == 0:
return

# count = np.zeros(tree.tree_sequence.num_nodes, dtype=int)
count = collections.Counter()
for sample in samples:
u = self.sample_ids[sample]
while u != tskit.NULL:
count[u] += 1
u = tree.parent(u)
# Go up the tree until we find the first node ancestral to all samples.
node = self.sample_ids[samples[0]]
while count[node] < num_samples:
node = tree.parent(node)
assert count[node] == num_samples
# Look at the children of this node and put down mutations appropriately.
split_children = False
for child in tree.children(node):
if count[child] == 0:
split_children = True
break
mutation_nodes = [node]
if split_children:
mutation_nodes = []
for child in tree.children(node):
if count[child] > 0:
mutation_nodes.append(child)
for mutation_node in mutation_nodes:
parent_mutation = mutations.add_row(
site=site, node=mutation_node, derived_state=variant.alleles[1])
# Traverse down the tree to find any leaves that do not have this
# mutation and insert back mutations.
for node in tree.nodes(mutation_node):
if tree.is_sample(node) and count[node] == 0:
mutations.add_row(
site=site, node=node, derived_state=variant.alleles[0],
parent=parent_mutation)

def insert_sites(self, tables):
"""
Insert the sites in the sample data that were not marked for inference,
Expand All @@ -688,6 +640,7 @@ def insert_sites(self, tables):
_, node, derived_state, parent = self.tree_sequence_builder.dump_mutations()
ts = tables.tree_sequence()
if num_non_inference_sites > 0:
assert ts.num_edges > 0
logger.info(
"Starting mutation positioning for {} non inference sites".format(
num_non_inference_sites))
Expand All @@ -696,21 +649,34 @@ def insert_sites(self, tables):
tree = next(trees)
for variant in self.sample_data.variants():
site = variant.site
predefined_anc_state = site.ancestral_state
while tree.interval[1] <= site.position:
tree = next(trees)
assert tree.interval[0] <= site.position < tree.interval[1]
tables.sites.add_row(
position=site.position,
ancestral_state=site.ancestral_state,
ancestral_state=predefined_anc_state,
metadata=self.encode_metadata(site.metadata))
if site.inference == 1:
tables.mutations.add_row(
site=site.id, node=node[inferred_site],
derived_state=variant.alleles[derived_state[inferred_site]])
inferred_site += 1
else:
assert ts.num_edges > 0
self.locate_mutations_on_tree(tree, variant, tables.mutations)
inferred_anc_state, mapped_mutations = tree.map_mutations(
variant.genotypes, variant.alleles)
if inferred_anc_state != predefined_anc_state:
# We need to set the ancestral state to that defined in the
# original file
for root_node in tree.roots:
# Add a transition at each root to the mapped value
tables.mutations.add_row(
site=site.id, node=root_node,
derived_state=inferred_anc_state)
for mutation in mapped_mutations:
tables.mutations.add_row(
site=site.id, node=mutation.node,
derived_state=mutation.derived_state)
progress_monitor.update()
else:
# Simple case where all sites are inference sites. We save a lot of time here
Expand Down