Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 49 additions & 10 deletions graphtage/bounds.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

import logging
from functools import wraps
from typing import Iterable, Iterator, Optional, TypeVar, Union
from typing import Iterable, Iterator, List, Optional, Tuple, TypeVar, Union
from typing_extensions import Protocol

from intervaltree import Interval, IntervalTree
Expand Down Expand Up @@ -384,37 +384,67 @@ def min_bounded(bounds: Iterator[B]) -> B:
def make_distinct(*bounded: Bounded):
"""Ensures that all of the provided bounded arguments are tightened until they are finite and
either definitive or non-overlapping with any of the other arguments."""
import heapq

tree: IntervalTree = IntervalTree()
# Use a max-heap (negative sizes) to find biggest intervals in O(log n)
# Heap entries: (-size, id(interval), interval) - id breaks ties deterministically
size_heap: List[Tuple[int, int, Interval]] = []

for b in bounded:
if not b.bounds().finite:
b.tighten_bounds()
if not b.bounds().finite:
raise ValueError(f"Could not tighten {b!r} to a finite bound")
tree.add(Interval(b.bounds().lower_bound, b.bounds().upper_bound + 1, b))
interval = Interval(b.bounds().lower_bound, b.bounds().upper_bound + 1, b)
tree.add(interval)
size = interval.end - interval.begin
heapq.heappush(size_heap, (-size, id(interval), interval))

# Track which intervals are still valid (not removed from tree)
valid_intervals: set = {id(iv) for iv in tree}

while len(tree) > 1:
# find the biggest interval in the tree
# Pop from heap until we find a valid interval (still in tree)
biggest: Optional[Interval] = None
for m in tree:
m_size = m.end - m.begin
if biggest is None or m_size > biggest.end - biggest.begin:
biggest = m
assert biggest is not None
while size_heap:
neg_size, iv_id, candidate = heapq.heappop(size_heap)
if iv_id in valid_intervals:
# Verify size is still accurate (bounds may have changed)
current_size = candidate.end - candidate.begin
if current_size == -neg_size:
biggest = candidate
break
else:
# Size changed, re-add with correct size
heapq.heappush(size_heap, (-current_size, iv_id, candidate))

if biggest is None:
break

if biggest.data.bounds().definitive():
# This means that all intervals are points, so we are done!
# This means all intervals are points, so we are done!
break

tree.remove(biggest)
valid_intervals.discard(id(biggest))

matching = tree[biggest.begin:biggest.end]
if len(matching) < 1:
# This interval does not intersect any others, so it is distinct
continue
# now find the biggest other interval that intersects with biggest:

# Find the biggest intersecting interval (linear search over smaller set)
second_biggest: Optional[Interval] = None
for m in matching:
m_size = m.end - m.begin
if second_biggest is None or m_size > second_biggest.end - second_biggest.begin:
second_biggest = m

assert second_biggest is not None
tree.remove(second_biggest)
valid_intervals.discard(id(second_biggest))

# Shrink the two biggest intervals until they are distinct
while True:
biggest_bound: Range = biggest.data.bounds()
Expand All @@ -425,17 +455,26 @@ def make_distinct(*bounded: Bounded):
break
biggest.data.tighten_bounds()
second_biggest.data.tighten_bounds()

# Re-add intervals if they still overlap with others
new_interval = Interval(
begin=biggest.data.bounds().lower_bound,
end=biggest.data.bounds().upper_bound + 1,
data=biggest.data
)
if tree.overlaps(new_interval.begin, new_interval.end):
tree.add(new_interval)
valid_intervals.add(id(new_interval))
size = new_interval.end - new_interval.begin
heapq.heappush(size_heap, (-size, id(new_interval), new_interval))

new_interval = Interval(
begin=second_biggest.data.bounds().lower_bound,
end=second_biggest.data.bounds().upper_bound + 1,
data=second_biggest.data
)
if tree.overlaps(new_interval.begin, new_interval.end):
tree.add(new_interval)
valid_intervals.add(id(new_interval))
size = new_interval.end - new_interval.begin
heapq.heappush(size_heap, (-size, id(new_interval), new_interval))
Loading