Skip to content

Commit eadaa37

Browse files
committed
Change simplify to use keep_unary=True, rationalise param order, and document
1 parent 08c2d5f commit eadaa37

File tree

3 files changed

+69
-43
lines changed

3 files changed

+69
-43
lines changed

evaluation.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -373,9 +373,13 @@ def edge_plot(ts, filename):
373373
for tree in ts.trees():
374374
left, right = tree.interval
375375
for u in tree.nodes():
376-
for c in tree.children(u):
377-
lines.append([(left, c), (right, c)])
378-
colours.append(pallete[unrank(tree.samples(c), n)])
376+
children = tree.children(u)
377+
# Don't bother plotting unary nodes, which will all have the same
378+
# samples under them as their next non-unary descendant
379+
if len(children) > 1:
380+
for c in children:
381+
lines.append([(left, c), (right, c)])
382+
colours.append(pallete[unrank(tree.samples(c), n)])
379383

380384
lc = mc.LineCollection(lines, linewidths=2, colors=colours)
381385
fig, ax = plt.subplots()

tsinfer/eval_util.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -705,6 +705,8 @@ def run_perfect_inference(
705705
num_threads=num_threads, extended_checks=extended_checks,
706706
progress_monitor=progress_monitor,
707707
stabilise_node_ordering=time_chunking and not path_compression)
708+
# to compare against the original, we need to remove unary nodes from the inferred TS
709+
inferred_ts = inferred_ts.simplify(keep_unary=False, filter_sites=False)
708710
return ts, inferred_ts
709711

710712

tsinfer/inference.py

Lines changed: 60 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -139,10 +139,10 @@ def verify(samples, tree_sequence, progress_monitor=None):
139139

140140

141141
def infer(
142-
sample_data, progress_monitor=None, num_threads=0, path_compression=True,
143-
simplify=True, engine=constants.C_ENGINE):
142+
sample_data, num_threads=0, path_compression=True, simplify=True,
143+
engine=constants.C_ENGINE, progress_monitor=None):
144144
"""
145-
infer(sample_data, num_threads=0)
145+
infer(sample_data, num_threads=0, path_compression=True, simplify=True)
146146
147147
Runs the full :ref:`inference pipeline <sec_inference>` on the specified
148148
:class:`SampleData` instance and returns the inferred
@@ -153,13 +153,19 @@ def infer(
153153
:param int num_threads: The number of worker threads to use in parallelised
154154
sections of the algorithm. If <= 0, do not spawn any threads and
155155
use simpler sequential algorithms (default).
156+
:param bool path_compression: Whether to merge edges that share identical
157+
paths (essentially taking advantage of shared recombination breakpoints).
158+
:param bool simplify: Whether to remove extra tree nodes and edges that are not
159+
on a path between the root and any of the samples. To do so, the final tree
160+
sequence is simplified by applying the :meth:`tskit.TreeSequence.simplify` method
161+
with ``keep_unary`` set to True. The default for ``simplify`` is ``True``.
156162
:returns: The :class:`tskit.TreeSequence` object inferred from the
157163
input sample data.
158164
:rtype: tskit.TreeSequence
159165
"""
160166
ancestor_data = generate_ancestors(
161-
sample_data, engine=engine, progress_monitor=progress_monitor,
162-
num_threads=num_threads)
167+
sample_data, num_threads=num_threads,
168+
engine=engine, progress_monitor=progress_monitor,)
163169
ancestors_ts = match_ancestors(
164170
sample_data, ancestor_data, engine=engine, num_threads=num_threads,
165171
path_compression=path_compression, progress_monitor=progress_monitor)
@@ -171,8 +177,8 @@ def infer(
171177

172178

173179
def generate_ancestors(
174-
sample_data, num_threads=0, progress_monitor=None, engine=constants.C_ENGINE,
175-
**kwargs):
180+
sample_data, num_threads=0, path=None,
181+
engine=constants.C_ENGINE, progress_monitor=None, **kwargs):
176182
"""
177183
generate_ancestors(sample_data, num_threads=0, path=None, **kwargs)
178184
@@ -186,30 +192,32 @@ def generate_ancestors(
186192
187193
ancestor_data = tsinfer.generate_ancestors(sample_data, path="mydata.ancestors")
188194
189-
All other keyword arguments are passed to the :class:`AncestorData` constructor,
195+
Other keyword arguments are passed to the :class:`AncestorData` constructor,
190196
which may be used to control the storage properties of the generated file.
191197
192198
:param SampleData sample_data: The :class:`SampleData` instance that we are
193199
generating putative ancestors from.
194200
:param int num_threads: The number of worker threads to use. If < 1, use a
195201
simpler synchronous algorithm.
202+
:param str path: The path of the file to store the ancestor data. If None,
203+
the information is stored in memory and not persistent.
196204
:rtype: AncestorData
197205
:returns: The inferred ancestors stored in an :class:`AncestorData` instance.
198206
"""
199207
progress_monitor = _get_progress_monitor(progress_monitor)
200-
with formats.AncestorData(sample_data, **kwargs) as ancestor_data:
208+
with formats.AncestorData(sample_data, path=path, **kwargs) as ancestor_data:
201209
generator = AncestorsGenerator(
202-
sample_data, ancestor_data, progress_monitor, engine=engine,
203-
num_threads=num_threads)
210+
sample_data, ancestor_data, num_threads=num_threads,
211+
engine=engine, progress_monitor=progress_monitor)
204212
generator.add_sites()
205213
generator.run()
206214
ancestor_data.record_provenance("generate-ancestors")
207215
return ancestor_data
208216

209217

210218
def match_ancestors(
211-
sample_data, ancestor_data, progress_monitor=None, num_threads=0,
212-
path_compression=True, extended_checks=False, engine=constants.C_ENGINE):
219+
sample_data, ancestor_data, num_threads=0, path_compression=True,
220+
extended_checks=False, engine=constants.C_ENGINE, progress_monitor=None):
213221
"""
214222
match_ancestors(sample_data, ancestor_data, num_threads=0, path_compression=True)
215223
@@ -225,24 +233,25 @@ def match_ancestors(
225233
a history for.
226234
:param int num_threads: The number of match worker threads to use. If
227235
this is <= 0 then a simpler sequential algorithm is used (default).
228-
:param bool path_compression: Should we try to merge edges that share identical
229-
paths (essentially taking advantage of shared recombination breakpoints)
236+
:param bool path_compression: Whether to merge edges that share identical
237+
paths (essentially taking advantage of shared recombination breakpoints).
230238
:return: The ancestors tree sequence representing the inferred history
231239
of the set of ancestors.
232240
:rtype: tskit.TreeSequence
233241
"""
234242
matcher = AncestorMatcher(
235-
sample_data, ancestor_data, engine=engine,
236-
progress_monitor=progress_monitor, path_compression=path_compression,
237-
num_threads=num_threads, extended_checks=extended_checks)
243+
sample_data, ancestor_data, num_threads=num_threads,
244+
path_compression=path_compression, extended_checks=extended_checks,
245+
engine=engine, progress_monitor=progress_monitor)
238246
return matcher.match_ancestors()
239247

240248

241249
def augment_ancestors(
242-
sample_data, ancestors_ts, indexes, progress_monitor=None, num_threads=0,
243-
path_compression=True, extended_checks=False, engine=constants.C_ENGINE):
250+
sample_data, ancestors_ts, indexes, num_threads=0, path_compression=True,
251+
extended_checks=False, engine=constants.C_ENGINE, progress_monitor=None):
244252
"""
245-
augment_ancestors(sample_data, ancestors_ts, indexes, num_threads=0, simplify=True)
253+
augment_ancestors(sample_data, ancestors_ts, indexes, num_threads=0,
254+
path_compression=True)
246255
247256
Runs the sample matching :ref:`algorithm <sec_inference_match_samples>`
248257
on the specified :class:`SampleData` instance and ancestors tree sequence,
@@ -260,32 +269,34 @@ def augment_ancestors(
260269
tree sequence.
261270
:param int num_threads: The number of match worker threads to use. If
262271
this is <= 0 then a simpler sequential algorithm is used (default).
272+
:param bool path_compression: Whether to merge edges that share identical
273+
paths (essentially taking advantage of shared recombination breakpoints).
263274
:return: The specified ancestors tree sequence augmented with copying
264275
paths for the specified sample.
265276
:rtype: tskit.TreeSequence
266277
"""
267278
manager = SampleMatcher(
268-
sample_data, ancestors_ts, path_compression=path_compression,
269-
engine=engine, progress_monitor=progress_monitor, num_threads=num_threads,
270-
extended_checks=extended_checks)
279+
sample_data, ancestors_ts, num_threads=num_threads,
280+
path_compression=path_compression, extended_checks=extended_checks,
281+
engine=engine, progress_monitor=progress_monitor
282+
)
271283
manager.match_samples(indexes)
272284
ts = manager.get_augmented_ancestors_tree_sequence(indexes)
273285
return ts
274286

275287

276288
def match_samples(
277-
sample_data, ancestors_ts, progress_monitor=None, num_threads=0,
278-
path_compression=True, simplify=True, extended_checks=False,
279-
stabilise_node_ordering=False, engine=constants.C_ENGINE):
289+
sample_data, ancestors_ts, num_threads=0, path_compression=True, simplify=True,
290+
extended_checks=False, stabilise_node_ordering=False, engine=constants.C_ENGINE,
291+
progress_monitor=None):
280292
"""
281-
match_samples(sample_data, ancestors_ts, num_threads=0, simplify=True)
293+
match_samples(sample_data, ancestors_ts, num_threads=0, path_compression=True,
294+
simplify=True)
282295
283296
Runs the sample matching :ref:`algorithm <sec_inference_match_samples>`
284297
on the specified :class:`SampleData` instance and ancestors tree sequence,
285298
returning the final :class:`tskit.TreeSequence` instance containing
286-
the full inferred history for all samples and sites. If ``simplify`` is
287-
True (the default) run :meth:`tskit.TreeSequence.simplify` on the
288-
inferred tree sequence to ensure that it is in canonical form.
299+
the full inferred history for all samples and sites.
289300
290301
:param SampleData sample_data: The :class:`SampleData` instance
291302
representing the input data.
@@ -294,14 +305,20 @@ def match_samples(
294305
history among ancestral haplotypes.
295306
:param int num_threads: The number of match worker threads to use. If
296307
this is <= 0 then a simpler sequential algorithm is used (default).
308+
:param bool path_compression: Whether to merge edges that share identical
309+
paths (essentially taking advantage of shared recombination breakpoints).
310+
:param bool simplify: Whether to remove extra tree nodes and edges that are not
311+
on a path between the root and any of the samples. To do so, the final tree
312+
sequence is simplified by applying the :meth:`tskit.TreeSequence.simplify` method
313+
with ``keep_unary`` set to True. The default for ``simplify`` is ``True``.
297314
:return: The tree sequence representing the inferred history
298315
of the sample.
299316
:rtype: tskit.TreeSequence
300317
"""
301318
manager = SampleMatcher(
302-
sample_data, ancestors_ts, path_compression=path_compression,
303-
engine=engine, progress_monitor=progress_monitor, num_threads=num_threads,
304-
extended_checks=extended_checks)
319+
sample_data, ancestors_ts, num_threads=num_threads,
320+
path_compression=path_compression, extended_checks=extended_checks,
321+
engine=engine, progress_monitor=progress_monitor)
305322
manager.match_samples()
306323
ts = manager.finalise(
307324
simplify=simplify, stabilise_node_ordering=stabilise_node_ordering)
@@ -313,7 +330,8 @@ class AncestorsGenerator(object):
313330
Manages the process of building ancestors.
314331
"""
315332
def __init__(
316-
self, sample_data, ancestor_data, progress_monitor, engine, num_threads=0):
333+
self, sample_data, ancestor_data, num_threads=0,
334+
engine=constants.C_ENGINE, progress_monitor=None):
317335
self.sample_data = sample_data
318336
self.ancestor_data = ancestor_data
319337
self.progress_monitor = progress_monitor
@@ -453,8 +471,8 @@ def run(self):
453471
class Matcher(object):
454472

455473
def __init__(
456-
self, sample_data, num_threads=1, engine=constants.C_ENGINE,
457-
path_compression=True, progress_monitor=None, extended_checks=False):
474+
self, sample_data, num_threads=1, path_compression=True,
475+
extended_checks=False, engine=constants.C_ENGINE, progress_monitor=None):
458476
self.sample_data = sample_data
459477
self.num_threads = num_threads
460478
self.path_compression = path_compression
@@ -1106,8 +1124,9 @@ def finalise(self, simplify=True, stabilise_node_ordering=False):
11061124
logger.info("Finalising tree sequence")
11071125
ts = self.get_samples_tree_sequence()
11081126
if simplify:
1109-
logger.info("Running simplify on {} nodes and {} edges".format(
1110-
ts.num_nodes, ts.num_edges))
1127+
logger.info(
1128+
"Running simplify(keep_unary=True) on {} nodes and {} edges".format(
1129+
ts.num_nodes, ts.num_edges))
11111130
if stabilise_node_ordering:
11121131
# Ensure all the node times are distinct so that they will have
11131132
# stable IDs after simplifying. This could possibly also be done
@@ -1122,7 +1141,8 @@ def finalise(self, simplify=True, stabilise_node_ordering=False):
11221141
tables.nodes.set_columns(flags=tables.nodes.flags, time=time)
11231142
tables.sort()
11241143
ts = tables.tree_sequence()
1125-
ts = ts.simplify(samples=self.sample_ids, filter_sites=False)
1144+
ts = ts.simplify(
1145+
samples=self.sample_ids, filter_sites=False, keep_unary=True)
11261146
logger.info("Finished simplify; now have {} nodes and {} edges".format(
11271147
ts.num_nodes, ts.num_edges))
11281148
return ts

0 commit comments

Comments
 (0)