Skip to content

Don't load all ancestors when truncating #811

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ In development
containing sequences with substantial amounts of error.
({pr}`761`, {user}`jeromekelleher`)

- `truncate_ancestors` no longer requires loading all the ancestors into RAM.
({pr}`811`, {user}`benjeffery`)

## [0.3.0] - 2022-10-25

**Features**
Expand Down
4 changes: 3 additions & 1 deletion tests/test_formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -2420,7 +2420,9 @@ def test_ancestors_truncated_length(self):
ancestors = tsinfer.generate_ancestors(sample_data)
lower_limit = 0.4
upper_limit = 0.6
trunc_anc = ancestors.truncate_ancestors(lower_limit, upper_limit, 1)
trunc_anc = ancestors.truncate_ancestors(
lower_limit, upper_limit, 1, buffer_length=1
)
original_lengths = ancestors.ancestors_length[:]
trunc_lengths = trunc_anc.ancestors_length[:]
# Check that ancestors older than upper_limit have been cut down
Expand Down
67 changes: 66 additions & 1 deletion tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,7 +546,7 @@ def verify_data_round_trip(
params = [(0.4, 0.6, 1), (0, 1, 10)]
for param in params:
truncated_ancestors = ancestors.truncate_ancestors(
param[0], param[1], param[2]
param[0], param[1], param[2], buffer_length=2
)
engines = [tsinfer.C_ENGINE, tsinfer.PY_ENGINE]
for engine in engines:
Expand All @@ -568,6 +568,71 @@ def verify_data_round_trip(
)


class TestTruncateAncestorsRoundTripFromDisk(TestRoundTrip):
"""
Tests that we can correctly round trip data when we truncate ancestral haplotypes
which have come from disk
"""

def verify_data_round_trip(
self,
genotypes,
positions,
alleles=None,
sequence_length=None,
site_times=None,
individual_times=None,
ancestral_alleles=None,
):
sample_data = self.create_sample_data(
genotypes,
positions,
alleles,
sequence_length,
site_times,
individual_times,
ancestral_alleles,
)
with tempfile.TemporaryDirectory() as d:
tsinfer.generate_ancestors(sample_data, path=d + "ancestors.tsi")
ancestors = tsinfer.AncestorData.load(d + "ancestors.tsi")
time = np.sort(ancestors.ancestors_time[:])
# Some tests produce an AncestorData file with no ancestors
if len(time) > 0:
lower_bound = np.min(time)
upper_bound = np.max(time)
midpoint = np.median(time)
params = [
(lower_bound, upper_bound, 0.1),
(lower_bound, upper_bound, 1),
(midpoint, midpoint + (midpoint / 2), 1),
]
else:
params = [(0.4, 0.6, 1), (0, 1, 10)]
for param in params:
truncated_ancestors = ancestors.truncate_ancestors(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will ancestors.truncate_ancestors(*param, buffer_length=2) work here?

*param, buffer_length=2
)
engines = [tsinfer.C_ENGINE, tsinfer.PY_ENGINE]
for engine in engines:
ancestors_ts = tsinfer.match_ancestors(
sample_data, truncated_ancestors, engine=engine
)
ts = tsinfer.match_samples(
sample_data,
ancestors_ts,
engine=engine,
)
self.assert_lossless(
ts,
genotypes,
positions,
alleles,
sample_data.sequence_length,
ancestral_alleles,
)


class TestSparseAncestorsRoundTrip(TestRoundTrip):
"""
Tests that we correctly round trip data when we generate the sparsest possible
Expand Down
80 changes: 60 additions & 20 deletions tsinfer/formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -3008,6 +3008,7 @@ def truncate_ancestors(
lower_time_bound,
upper_time_bound,
length_multiplier=2,
buffer_length=1000,
**kwargs,
):
"""
Expand Down Expand Up @@ -3047,6 +3048,8 @@ def truncate_ancestors(
and ``upper_time_bound`` (exclusive), i.e.
if the longest ancestor in the interval is 1 megabase, a
``length_multiplier`` of 2 creates a maximum length of 2 megabases.
:param int buffer_length: The number of changed ancestors to buffer before
writing to disk.
:param \\**kwargs: Further arguments passed to the :func:`AncestorData.copy`
when creating the new :class:`AncestorData` instance which will be returned.

Expand All @@ -3065,15 +3068,10 @@ def truncate_ancestors(
raise ValueError("Upper bound must be >= lower bound")

position = self.sites_position[:]
start = self.ancestors_start[:]
end = self.ancestors_end[:]
time = self.ancestors_time[:]
focal_sites = self.ancestors_focal_sites[:]
haplotypes = self.ancestors_full_haplotype[:]
if upper_time_bound > np.max(time) or lower_time_bound > np.max(time):
raise ValueError("Time bounds cannot be greater than older ancestor")

truncated = self.copy(**kwargs)
anc_in_bound = np.logical_and(
time >= lower_time_bound,
time < upper_time_bound,
Expand All @@ -3082,7 +3080,50 @@ def truncate_ancestors(
raise ValueError("No ancestors in time bound")
max_length = length_multiplier * np.max(self.ancestors_length[:][anc_in_bound])

for anc in self.ancestors():
truncated = self.copy(**kwargs)

# Create a buffer of buffer_length ancestors with their indexes
index_buffer = np.zeros(buffer_length, dtype=np.int32)
start_buffer = np.zeros(buffer_length, dtype=self.ancestors_start.dtype)
end_buffer = np.zeros(buffer_length, dtype=self.ancestors_end.dtype)
time_buffer = np.zeros(buffer_length, dtype=self.ancestors_time.dtype)
focal_sites_buffer = np.zeros(
buffer_length, dtype=self.ancestors_focal_sites.dtype
)
haplotype_buffer = np.full(
(self.ancestors_full_haplotype.shape[0], buffer_length, 1),
tskit.MISSING_DATA,
dtype=self.ancestors_full_haplotype.dtype,
)
buffer_pos = 0

def flush_buffers(buffer_pos):
# Write the first ``buffer_pos`` buffered ancestors into ``truncated``.
#
# As we find ancestors that need to be truncated, we write them to the
# buffers, with index_buffer storing the index of each ancestor in the
# original AncestorData file. We can then pass this index array to zarr's
# set_orthogonal_selection so that only the changed entries are written to
# the new AncestorData file, rather than rewriting every ancestor.
#
# NOTE(review): nothing in this function resets the buffers after writing;
# confirm that each slot is fully re-initialised before it is reused.
truncated.ancestors_start.set_orthogonal_selection(
index_buffer[:buffer_pos], start_buffer[:buffer_pos]
)
truncated.ancestors_end.set_orthogonal_selection(
index_buffer[:buffer_pos], end_buffer[:buffer_pos]
)
truncated.ancestors_time.set_orthogonal_selection(
index_buffer[:buffer_pos], time_buffer[:buffer_pos]
)
truncated.ancestors_focal_sites.set_orthogonal_selection(
index_buffer[:buffer_pos], focal_sites_buffer[:buffer_pos]
)
# For the haplotype arrays the first axis is selected in full (slice(None))
# and the changed ancestors are selected along the second axis.
truncated.ancestors_full_haplotype.set_orthogonal_selection(
(slice(None), index_buffer[:buffer_pos]),
haplotype_buffer[:, :buffer_pos],
)
# The mask is derived from the buffered haplotypes: it is True wherever the
# buffer holds tskit.MISSING_DATA.
truncated.ancestors_full_haplotype_mask.set_orthogonal_selection(
(slice(None), index_buffer[:buffer_pos]),
haplotype_buffer[:, :buffer_pos] == tskit.MISSING_DATA,
)

for anc_index, anc in enumerate(self.ancestors()):
if anc.time >= upper_time_bound and len(anc.focal_sites) > 0:
if position[anc.end - 1] - position[anc.start] > max_length:
left_focal_pos = position[np.min(anc.focal_sites)]
Expand All @@ -3104,21 +3145,20 @@ def truncate_ancestors(
f"Truncating ancestor {anc.id} at time {anc.time}"
"Original length {original_length}. New length {new_length}"
)
start[anc.id] = insert_pos_start
end[anc.id] = insert_pos_end
time[anc.id] = anc.time
focal_sites[anc.id] = anc.focal_sites
haplotypes[:, anc.id] = tskit.MISSING_DATA
haplotypes[
insert_pos_start:insert_pos_end, anc.id, 0
index_buffer[buffer_pos] = anc_index
start_buffer[buffer_pos] = insert_pos_start
end_buffer[buffer_pos] = insert_pos_end
time_buffer[buffer_pos] = anc.time
focal_sites_buffer[buffer_pos] = anc.focal_sites
haplotype_buffer[
insert_pos_start:insert_pos_end, buffer_pos, 0
] = anc.full_haplotype[insert_pos_start:insert_pos_end]
# TODO - record truncation in ancestors' metadata when supported
truncated.ancestors_start[:] = start
truncated.ancestors_end[:] = end
truncated.ancestors_time[:] = time
truncated.ancestors_focal_sites[:] = focal_sites
truncated.ancestors_full_haplotype[:] = haplotypes
truncated.ancestors_full_haplotype_mask[:] = haplotypes == tskit.MISSING_DATA
buffer_pos += 1
if buffer_pos == buffer_length:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A local function would help here,

def flush_buffer(length):
       truncated.ancestors_start.set_orthogonal_selection(
              index_buffer[:length], start_buffer[:length]
         )
        # etc

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't follow how this is working, so a few comments on how flush_buffer works would be helpful here.

flush_buffers(buffer_length)
buffer_pos = 0
if buffer_pos > 0:
flush_buffers(buffer_pos)
truncated.record_provenance(command="truncate_ancestors")
truncated.finalise()

Expand Down